Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ private void initializeLifecycleManager(String appId) {
if (lifecycleManager == null) {
appUniqueId = celebornConf.appUniqueIdWithUUIDSuffix(appId);
lifecycleManager = new LifecycleManager(appUniqueId, celebornConf);
lifecycleManager.registerCancelShuffleCallback(SparkUtils::cancelShuffle);
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@
import java.lang.reflect.Method;
import java.util.concurrent.atomic.LongAdder;

import scala.Option;
import scala.Some;
import scala.Tuple2;

import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
Expand All @@ -39,6 +44,7 @@
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;

public class SparkUtils {
private static final Logger logger = LoggerFactory.getLogger(SparkUtils.class);
Expand Down Expand Up @@ -179,4 +185,22 @@ public static void addFailureListenerIfBarrierTask(
shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier);
});
}

// Unbound accessor for the DAGScheduler field named "shuffleIdToMapStage". It is bound to a
// concrete scheduler instance in cancelShuffle, where the value is cast to a mutable Scala map
// from shuffle id to ShuffleMapStage.
// NOTE(review): hiddenImpl presumably resolves a non-public field reflectively, and build()
// presumably fails if the field is absent in the running Spark version -- confirm against
// DynFields.
private static final DynFields.UnboundField shuffleIdToMapStage_FIELD =
DynFields.builder().hiddenImpl(DAGScheduler.class, "shuffleIdToMapStage").build();

/**
 * Cancels the Spark stage registered as the map stage of the given shuffle, if there is one.
 *
 * <p>The stage is looked up in the DAGScheduler's {@code shuffleIdToMapStage} map, read through
 * {@code shuffleIdToMapStage_FIELD}. Nothing happens when the map has no entry for the shuffle.
 * When no SparkContext is active, an error is logged and the call returns.
 *
 * @param shuffleId id of the shuffle whose map stage should be cancelled
 * @param reason message handed to the scheduler as the cancellation reason
 */
public static void cancelShuffle(int shuffleId, String reason) {
  // Read the active context exactly once. Checking nonEmpty() on one lookup and calling get()
  // on a second lookup can throw NoSuchElementException if the context is deactivated in
  // between the two calls.
  Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.isEmpty()) {
    logger.error("Can not get active SparkContext, skip cancelShuffle.");
    return;
  }

  DAGScheduler scheduler = activeContext.get().dagScheduler();
  scala.collection.mutable.Map<Integer, ShuffleMapStage> shuffleIdToMapStageValue =
      (scala.collection.mutable.Map<Integer, ShuffleMapStage>)
          shuffleIdToMapStage_FIELD.bind(scheduler).get();
  Option<ShuffleMapStage> shuffleMapStage = shuffleIdToMapStageValue.get(shuffleId);
  if (shuffleMapStage.nonEmpty()) {
    scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ private void initializeLifecycleManager(String appId) {
if (lifecycleManager == null) {
appUniqueId = celebornConf.appUniqueIdWithUUIDSuffix(appId);
lifecycleManager = new LifecycleManager(appUniqueId, celebornConf);
lifecycleManager.registerCancelShuffleCallback(SparkUtils::cancelShuffle);
if (celebornConf.clientFetchThrowsFetchFailure()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@

import java.util.concurrent.atomic.LongAdder;

import scala.Option;
import scala.Some;
import scala.Tuple2;

import org.apache.spark.BarrierTaskContext;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
Expand Down Expand Up @@ -266,6 +271,9 @@ public static <K, C> CelebornShuffleReader<K, C> createColumnarShuffleReader(
.orNoop()
.build();

// Unbound accessor for the DAGScheduler field named "shuffleIdToMapStage". It is bound to a
// concrete scheduler instance in cancelShuffle, where the value is cast to a mutable Scala map
// from shuffle id to ShuffleMapStage.
// NOTE(review): hiddenImpl presumably resolves a non-public field reflectively, and build()
// presumably fails if the field is absent in the running Spark version -- confirm against
// DynFields.
private static final DynFields.UnboundField shuffleIdToMapStage_FIELD =
DynFields.builder().hiddenImpl(DAGScheduler.class, "shuffleIdToMapStage").build();

public static void unregisterAllMapOutput(
MapOutputTrackerMaster mapOutputTracker, int shuffleId) {
if (!UnregisterAllMapAndMergeOutput_METHOD.isNoop()) {
Expand Down Expand Up @@ -296,4 +304,19 @@ public static void addFailureListenerIfBarrierTask(
shuffleClient.reportBarrierTaskFailure(appShuffleId, appShuffleIdentifier);
});
}

/**
 * Cancels the Spark stage registered as the map stage of the given shuffle, if there is one.
 *
 * <p>The stage is looked up in the DAGScheduler's {@code shuffleIdToMapStage} map, read through
 * {@code shuffleIdToMapStage_FIELD}. Nothing happens when the map has no entry for the shuffle.
 * When no SparkContext is active, an error is logged and the call returns.
 *
 * @param shuffleId id of the shuffle whose map stage should be cancelled
 * @param reason message handed to the scheduler as the cancellation reason
 */
public static void cancelShuffle(int shuffleId, String reason) {
  // Read the active context exactly once. Checking nonEmpty() on one lookup and calling get()
  // on a second lookup can throw NoSuchElementException if the context is deactivated in
  // between the two calls.
  Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.isEmpty()) {
    LOG.error("Can not get active SparkContext, skip cancelShuffle.");
    return;
  }

  DAGScheduler scheduler = activeContext.get().dagScheduler();
  scala.collection.mutable.Map<Integer, ShuffleMapStage> shuffleIdToMapStageValue =
      (scala.collection.mutable.Map<Integer, ShuffleMapStage>)
          shuffleIdToMapStage_FIELD.bind(scheduler).get();
  Option<ShuffleMapStage> shuffleMapStage = shuffleIdToMapStageValue.get(shuffleId);
  if (shuffleMapStage.nonEmpty()) {
    scheduler.cancelStage(shuffleMapStage.get().id(), new Some<>(reason));
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.PbReviseLostShufflesResponse
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ReviseLostShuffles, ZERO_UUID}
import org.apache.celeborn.common.protocol.message.ControlMessages.{ApplicationLost, ApplicationLostResponse, CheckQuotaResponse, HeartbeatFromApplication, HeartbeatFromApplicationResponse, ReviseLostShuffles, ZERO_UUID}
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.{ThreadUtils, Utils}

Expand All @@ -39,7 +39,8 @@ class ApplicationHeartbeater(
masterClient: MasterClient,
shuffleMetrics: () => (Long, Long),
workerStatusTracker: WorkerStatusTracker,
registeredShuffles: ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]) extends Logging {
registeredShuffles: ConcurrentHashMap.KeySetView[Int, java.lang.Boolean],
cancelAllActiveStages: String => Unit) extends Logging {

private var stopped = false
private val reviseLostShuffles = conf.reviseLostShufflesEnabled
Expand Down Expand Up @@ -77,6 +78,7 @@ class ApplicationHeartbeater(
if (response.statusCode == StatusCode.SUCCESS) {
logDebug("Successfully send app heartbeat.")
workerStatusTracker.handleHeartbeatResponse(response)
checkQuotaExceeds(response.checkQuotaResponse)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a flag to tell Spark whether to cancel the stage when the quota is exceeded?

// revise shuffle id if there are lost shuffles
if (reviseLostShuffles) {
val masterRecordedShuffleIds = response.registeredShuffles
Expand Down Expand Up @@ -132,7 +134,8 @@ class ApplicationHeartbeater(
List.empty.asJava,
List.empty.asJava,
List.empty.asJava,
List.empty.asJava)
List.empty.asJava,
CheckQuotaResponse(isAvailable = true, ""))
}
}

Expand All @@ -149,6 +152,12 @@ class ApplicationHeartbeater(
}
}

/**
 * Cancels all active stages when interrupt-on-quota is enabled and the quota check carried by
 * the heartbeat response reports that quota is no longer available.
 *
 * @param response quota check result taken from the heartbeat response; a null value is
 *                 treated as "nothing to act on"
 */
private def checkQuotaExceeds(response: CheckQuotaResponse): Unit = {
  // HeartbeatFromApplicationResponse can be constructed with a null checkQuotaResponse, so
  // guard before dereferencing instead of failing the heartbeat handling with an NPE.
  if (conf.quotaInterruptShuffleEnabled && response != null && !response.isAvailable) {
    cancelAllActiveStages(response.reason)
  }
}

def stop(): Unit = {
stopped.synchronized {
if (!stopped) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import java.nio.ByteBuffer
import java.security.SecureRandom
import java.util
import java.util.{function, List => JList}
import java.util.concurrent.{Callable, ConcurrentHashMap, LinkedBlockingQueue, ScheduledFuture, TimeUnit}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.Consumer
import java.util.function.{BiConsumer, Consumer}

import scala.collection.JavaConverters._
import scala.collection.generic.CanBuildFrom
Expand Down Expand Up @@ -211,7 +211,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
masterClient,
() => commitManager.commitMetrics(),
workerStatusTracker,
registeredShuffle)
registeredShuffle,
reason => cancelAllActiveStages(reason))
private val changePartitionManager = new ChangePartitionManager(conf, this)
private val releasePartitionManager = new ReleasePartitionManager(conf, this)

Expand Down Expand Up @@ -1763,6 +1764,11 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
appShuffleDeterminateMap.put(appShuffleId, determinate)
}

// Callback invoked with (shuffleId, reason) to cancel a shuffle on the engine side.
// None until a callback has been registered.
@volatile private var cancelShuffleCallback: Option[BiConsumer[Integer, String]] = None

/**
 * Registers the callback used to cancel a shuffle; a later registration replaces the earlier
 * one.
 *
 * @param callback receives the shuffle id and the cancellation reason; passing null clears
 *                 the registration
 */
def registerCancelShuffleCallback(callback: BiConsumer[Integer, String]): Unit = {
  // Option(...) maps null to None. Some(null) would be stored as a present callback and fail
  // with an NPE when it is eventually invoked.
  cancelShuffleCallback = Option(callback)
}

// Initialize at the end of LifecycleManager construction.
initialize()

Expand All @@ -1781,4 +1787,16 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
rnd.nextBytes(secretBytes)
JavaUtils.bytesToString(ByteBuffer.wrap(secretBytes))
}

/**
 * Invokes the registered cancel-shuffle callback for every shuffle id in
 * `shuffleAllocatedWorkers` that `commitManager` does not report as stage-ended.
 * Does nothing when no callback has been registered.
 *
 * @param reason cancellation reason forwarded to the callback for each shuffle
 */
def cancelAllActiveStages(reason: String): Unit = {
  cancelShuffleCallback.foreach { callback =>
    // Collect the still-running shuffles first, then notify the callback for each of them.
    val activeShuffleIds =
      shuffleAllocatedWorkers.asScala.keys.filterNot(id => commitManager.isStageEnd(id))
    activeShuffleIds.foreach(id => callback.accept(id, reason))
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ class WorkerStatusTrackerSuite extends CelebornFunSuite {
unknownWorkers,
shuttingWorkers,
availableWorkers,
new util.ArrayList[Integer]())
new util.ArrayList[Integer](),
null)
}

private def mockWorkers(workerHosts: Array[String]): util.ArrayList[WorkerInfo] = {
Expand Down
1 change: 1 addition & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ message PbHeartbeatFromApplicationResponse {
repeated PbWorkerInfo shuttingWorkers = 4;
repeated int32 registeredShuffles = 5;
repeated PbWorkerInfo availableWorkers = 6;
PbCheckQuotaResponse checkQuotaResponse = 7;
}

message PbCheckQuota {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def quotaIdentityProviderClass: String = get(QUOTA_IDENTITY_PROVIDER)
def quotaUserSpecificTenant: String = get(QUOTA_USER_SPECIFIC_TENANT)
def quotaUserSpecificUserName: String = get(QUOTA_USER_SPECIFIC_USERNAME)
def quotaInterruptShuffleEnabled: Boolean = get(QUOTA_INTERRUPT_SHUFFLE_ENABLED)

// //////////////////////////////////////////////////////
// Client //
Expand Down Expand Up @@ -5333,6 +5334,14 @@ object CelebornConf extends Logging {
.longConf
.createWithDefault(Long.MaxValue)

// Boolean switch exposed through CelebornConf.quotaInterruptShuffleEnabled; disabled by default.
// When it is enabled and a heartbeat response reports that quota is not available,
// ApplicationHeartbeater invokes its cancelAllActiveStages hook with the reported reason.
val QUOTA_INTERRUPT_SHUFFLE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.quota.interruptShuffle.enabled")
.categories("quota", "client")
.version("0.6.0")
.doc("Whether to enable interrupt shuffle when quota exceeds.")
.booleanConf
.createWithDefault(false)

val COLUMNAR_SHUFFLE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.columnarShuffle.enabled")
.withAlternative("celeborn.columnar.shuffle.enabled")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ object ControlMessages extends Logging {
unknownWorkers: util.List[WorkerInfo],
shuttingWorkers: util.List[WorkerInfo],
availableWorkers: util.List[WorkerInfo],
registeredShuffles: util.List[Integer]) extends Message
registeredShuffles: util.List[Integer],
checkQuotaResponse: CheckQuotaResponse) extends Message

case class CheckQuota(userIdentifier: UserIdentifier) extends Message

Expand Down Expand Up @@ -832,7 +833,10 @@ object ControlMessages extends Logging {
unknownWorkers,
shuttingWorkers,
availableWorkers,
registeredShuffles) =>
registeredShuffles,
checkQuotaResponse) =>
val pbCheckQuotaResponse = PbCheckQuotaResponse.newBuilder().setAvailable(
checkQuotaResponse.isAvailable).setReason(checkQuotaResponse.reason)
val payload = PbHeartbeatFromApplicationResponse.newBuilder()
.setStatus(statusCode.getValue)
.addAllExcludedWorkers(
Expand All @@ -844,6 +848,7 @@ object ControlMessages extends Logging {
.addAllAvailableWorkers(
availableWorkers.asScala.map(PbSerDeUtils.toPbWorkerInfo(_, true, true)).toList.asJava)
.addAllRegisteredShuffles(registeredShuffles)
.setCheckQuotaResponse(pbCheckQuotaResponse)
.build().toByteArray
new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION_RESPONSE, payload)

Expand Down Expand Up @@ -1221,6 +1226,7 @@ object ControlMessages extends Logging {
case HEARTBEAT_FROM_APPLICATION_RESPONSE_VALUE =>
val pbHeartbeatFromApplicationResponse =
PbHeartbeatFromApplicationResponse.parseFrom(message.getPayload)
val pbCheckQuotaResponse = pbHeartbeatFromApplicationResponse.getCheckQuotaResponse
HeartbeatFromApplicationResponse(
Utils.toStatusCode(pbHeartbeatFromApplicationResponse.getStatus),
pbHeartbeatFromApplicationResponse.getExcludedWorkersList.asScala
Expand All @@ -1231,7 +1237,8 @@ object ControlMessages extends Logging {
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
pbHeartbeatFromApplicationResponse.getAvailableWorkersList.asScala
.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
pbHeartbeatFromApplicationResponse.getRegisteredShufflesList)
pbHeartbeatFromApplicationResponse.getRegisteredShufflesList,
CheckQuotaResponse(pbCheckQuotaResponse.getAvailable, pbCheckQuotaResponse.getReason))

case CHECK_QUOTA_VALUE =>
val pbCheckAvailable = PbCheckQuota.parseFrom(message.getPayload)
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ license: |
| celeborn.quota.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.2.0 | |
| celeborn.quota.identity.user-specific.tenant | default | false | Tenant id if celeborn.quota.identity.provider is org.apache.celeborn.common.identity.DefaultIdentityProvider. | 0.3.0 | |
| celeborn.quota.identity.user-specific.userName | default | false | User name if celeborn.quota.identity.provider is org.apache.celeborn.common.identity.DefaultIdentityProvider. | 0.3.0 | |
| celeborn.quota.interruptShuffle.enabled | false | false | Whether to enable interrupt shuffle when quota exceeds. | 0.6.0 | |
| celeborn.storage.availableTypes | HDD | false | Enabled storages. Available options: MEMORY,HDD,SSD,HDFS. Note: HDD and SSD would be treated as identical. | 0.3.0 | celeborn.storage.activeTypes |
| celeborn.storage.hdfs.dir | &lt;undefined&gt; | false | HDFS base directory for Celeborn to store shuffle data. | 0.2.0 | |
| celeborn.storage.s3.access.key | &lt;undefined&gt; | false | S3 access key for Celeborn to store shuffle data. | 0.6.0 | |
Expand Down
1 change: 1 addition & 0 deletions docs/configuration/quota.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ license: |
| celeborn.quota.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.2.0 | |
| celeborn.quota.identity.user-specific.tenant | default | false | Tenant id if celeborn.quota.identity.provider is org.apache.celeborn.common.identity.DefaultIdentityProvider. | 0.3.0 | |
| celeborn.quota.identity.user-specific.userName | default | false | User name if celeborn.quota.identity.provider is org.apache.celeborn.common.identity.DefaultIdentityProvider. | 0.3.0 | |
| celeborn.quota.interruptShuffle.enabled | false | false | Whether to enable interrupt shuffle when quota exceeds. | 0.6.0 | |
| celeborn.quota.tenant.diskBytesWritten | 9223372036854775807 | true | Quota dynamic configuration for written disk bytes. | 0.5.0 | |
| celeborn.quota.tenant.diskFileCount | 9223372036854775807 | true | Quota dynamic configuration for written disk file count. | 0.5.0 | |
| celeborn.quota.tenant.hdfsBytesWritten | 9223372036854775807 | true | Quota dynamic configuration for written hdfs bytes. | 0.5.0 | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,8 @@ private[celeborn] class Master(
new util.ArrayList[WorkerInfo](
(statusSystem.shutdownWorkers.asScala ++ statusSystem.decommissionWorkers.asScala).asJava),
availableWorksSentToClient,
new util.ArrayList(appRelatedShuffles)))
new util.ArrayList(appRelatedShuffles),
CheckQuotaResponse(isAvailable = true, "")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should actually check quota for the application.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, QuotaManager only supports checking quota per user. In the next PR we will add support for checking quota per application, so we don't need to modify the RPC proto in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

} else {
context.reply(OneWayMessageResponse)
}
Expand Down