@@ -258,7 +258,11 @@ public Properties cryptoConf() {
}

/**
* The max number of chunks allowed to being transferred at the same time on shuffle service.
* The max number of chunks allowed to be transferred at the same time on shuffle service.
* Note that new incoming connections will be closed when the max number is hit. The client will
* retry according to the shuffle retry configs (see `spark.shuffle.io.maxRetries` and
* `spark.shuffle.io.retryWait`); if those limits are reached, the task will fail with a fetch
* failure.
*/
public long maxChunksBeingTransferred() {
return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE);
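As a hedged illustration of the knobs the javadoc above refers to (not part of the diff): the transfer cap interacts with the client-side retry settings. The class name and the values below are invented for illustration; only the property names come from the code and javadoc above.

```java
import org.apache.spark.SparkConf;

// Illustrative tuning sketch; values are examples, not recommendations.
public class ShuffleTransferTuning {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .set("spark.shuffle.maxChunksBeingTransferred", "4096") // cap concurrent chunk transfers (default Long.MAX_VALUE)
        .set("spark.shuffle.io.maxRetries", "5")                 // fetch retries before the task fails (default 3)
        .set("spark.shuffle.io.retryWait", "10s");               // wait between consecutive retries (default 5s)
    // conf would then be passed to SparkContext / SparkSession.builder().config(conf).
  }
}
```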

@@ -105,7 +105,7 @@ private class ExtendedChannelPromise extends DefaultChannelPromise {
private List<GenericFutureListener> listeners = new ArrayList<>();
private boolean success;

public ExtendedChannelPromise(Channel channel) {
ExtendedChannelPromise(Channel channel) {
super(channel);
success = false;
}
@@ -127,7 +127,9 @@ public void finish(boolean success) {
listeners.forEach(listener -> {
try {
listener.operationComplete(this);
} catch (Exception e) { }
} catch (Exception e) {
// do nothing
}
});
}
}

@@ -120,14 +120,16 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.taskContext = taskContext;
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.fileBufferSizeBytes =
(int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024);
this.writeMetrics = writeMetrics;
this.inMemSorter = new ShuffleInMemorySorter(
this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true));
this.peakMemoryUsedBytes = getMemoryUsage();
this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
this.diskWriteBufferSize =
(int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
}

/**

@@ -22,6 +22,7 @@
import java.io.IOException;
import java.util.LinkedList;
import java.util.Queue;
import java.util.function.Supplier;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
@@ -48,8 +49,16 @@ public final class UnsafeExternalSorter extends MemoryConsumer {

@Nullable
private final PrefixComparator prefixComparator;

/**
* A {@link RecordComparator} may keep references to the records it compared most recently,
* so we should not keep a {@link RecordComparator} instance inside
* {@link UnsafeExternalSorter}, because {@link UnsafeExternalSorter} is referenced by
* {@link TaskContext} and thus cannot be garbage collected until the end of the task.
*/
@Nullable
private final RecordComparator recordComparator;
private final Supplier<RecordComparator> recordComparatorSupplier;

private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final SerializerManager serializerManager;
@@ -90,14 +99,14 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
Supplier<RecordComparator> recordComparatorSupplier,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
long numElementsForSpillThreshold,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize,
numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
@@ -110,14 +119,14 @@ public static UnsafeExternalSorter create(
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
Supplier<RecordComparator> recordComparatorSupplier,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
long numElementsForSpillThreshold,
boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes,
taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes,
numElementsForSpillThreshold, null, canUseRadixSort);
}

@@ -126,7 +135,7 @@ private UnsafeExternalSorter(
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
Supplier<RecordComparator> recordComparatorSupplier,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
@@ -138,15 +147,24 @@ private UnsafeExternalSorter(
this.blockManager = blockManager;
this.serializerManager = serializerManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.recordComparatorSupplier = recordComparatorSupplier;
this.prefixComparator = prefixComparator;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
// this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024
this.fileBufferSizeBytes = 32 * 1024;

if (existingInMemorySorter == null) {
RecordComparator comparator = null;
if (recordComparatorSupplier != null) {
comparator = recordComparatorSupplier.get();
}
this.inMemSorter = new UnsafeInMemorySorter(
this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort);
this,
taskMemoryManager,
comparator,
prefixComparator,
initialSize,
canUseRadixSort);
} else {
this.inMemSorter = existingInMemorySorter;
}
@@ -451,14 +469,14 @@ public void merge(UnsafeExternalSorter other) throws IOException {
* after consuming this iterator.
*/
public UnsafeSorterIterator getSortedIterator() throws IOException {
assert(recordComparator != null);
assert(recordComparatorSupplier != null);
if (spillWriters.isEmpty()) {
assert(inMemSorter != null);
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
return readingIterator;
} else {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(
recordComparatorSupplier.get(), prefixComparator, spillWriters.size());
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
}
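A minimal caller-side sketch of the Supplier-based API introduced above, assuming the post-change signature of `UnsafeExternalSorter.create` shown in this diff; the wrapper class, method, and parameter names are invented for illustration. The point of the design, per the javadoc above, is that the task-scoped sorter keeps only the factory, so the comparator instances it creates can be garbage collected once sorting is done instead of being pinned until the end of the task.

```java
import java.util.function.Supplier;

import org.apache.spark.TaskContext;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;

// Hypothetical helper: everything here except the UnsafeExternalSorter.create signature
// (taken from the diff above) is invented for illustration.
class SorterCreationSketch {
  UnsafeExternalSorter newSorter(
      TaskMemoryManager taskMemoryManager,
      BlockManager blockManager,
      SerializerManager serializerManager,
      TaskContext taskContext,
      Supplier<RecordComparator> comparatorFactory,   // e.g. () -> new MyRecordComparator()
      PrefixComparator prefixComparator,
      long pageSizeBytes,
      long numElementsForSpillThreshold) {
    // The sorter stores only the supplier; the comparators it materializes (for the
    // in-memory sorter or the spill merger) do not live as fields of the task-scoped
    // external sorter and can be collected after sorting.
    return UnsafeExternalSorter.create(
        taskMemoryManager, blockManager, serializerManager, taskContext,
        comparatorFactory, prefixComparator,
        /* initialSize */ 1024, pageSizeBytes, numElementsForSpillThreshold,
        /* canUseRadixSort */ false);
  }
}
```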

@@ -514,7 +514,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
outStream.println(
s"""
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local
| (Default: local[*]).
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
| on one of the worker machines inside the cluster ("cluster")
| (Default: client).

@@ -154,7 +154,7 @@ private UnsafeExternalSorter newSorter() throws IOException {
blockManager,
serializerManager,
taskContext,
recordComparator,
() -> recordComparator,
prefixComparator,
/* initialSize */ 1024,
pageSizeBytes,
@@ -440,7 +440,7 @@ public void testPeakMemoryUsed() throws Exception {
blockManager,
serializerManager,
taskContext,
recordComparator,
() -> recordComparator,
prefixComparator,
1024,
pageSizeBytes,

docs/configuration.md (5 additions, 1 deletion)
@@ -635,7 +635,11 @@ Apart from these, the following properties are also available, and may be useful
<td><code>spark.shuffle.maxChunksBeingTransferred</code></td>
<td>Long.MAX_VALUE</td>
<td>
The max number of chunks allowed to being transferred at the same time on shuffle service.
The max number of chunks allowed to be transferred at the same time on shuffle service.
Note that new incoming connections will be closed when the max number is hit. The client will
retry according to the shuffle retry configs (see <code>spark.shuffle.io.maxRetries</code> and
<code>spark.shuffle.io.retryWait</code>); if those limits are reached, the task will fail with
a fetch failure.
</td>
</tr>
<tr>

@@ -34,6 +34,7 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
private[ml] trait OneVsRestParams extends PredictorParams
with ClassifierTypeTrait with HasWeightCol {

/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -294,6 +296,18 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)

/**
* Sets the value of param [[weightCol]].
*
* This is ignored if weight is not supported by [[classifier]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("2.3.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
@@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") (
val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
instr.logNumClasses(numClasses)

val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && {
getClassifier match {
case _: HasWeightCol => true
case c =>
logWarning(s"weightCol is ignored, as it is not supported by $c now.")
false
}
}

val multiclassLabeled = if (weightColIsUsed) {
dataset.select($(labelCol), $(featuresCol), $(weightCol))
} else {
dataset.select($(labelCol), $(featuresCol))
}

// persist if underlying dataset is not persistent.
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
if (weightColIsUsed) {
val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
paramMap.put(classifier_.weightCol -> getWeightCol)
classifier_.fit(trainingDataset, paramMap)
} else {
classifier.fit(trainingDataset, paramMap)
}
}.toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)
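A hedged usage sketch of the new weightCol support added above: per the pattern match in fit(), the weight column is only honoured when the base classifier itself supports instance weights (for example LogisticRegression); otherwise a warning is logged and the column is ignored. The class name, the `training` Dataset, and the "weight" column are assumptions for illustration, not part of the diff.

```java
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

class OneVsRestWeightedSketch {
  // `training` is assumed to contain "label", "features", and a per-instance "weight" column.
  OneVsRestModel train(Dataset<Row> training) {
    OneVsRest ovr = new OneVsRest()
        .setClassifier(new LogisticRegression()) // LogisticRegression supports instance weights
        .setWeightCol("weight");                 // ignored, with a warning, for classifiers without weight support
    return ovr.fit(training);
  }
}
```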


@@ -83,11 +83,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setStringIndexerOrderType(stringIndexerOrderType)
checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
val schema = rFormulaModel.transform(data).schema
val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)

// assemble and fit the pipeline
val glr = new GeneralizedLinearRegression()
.setFamily(family)
@@ -113,37 +109,16 @@ private[r] object GeneralizedLinearRegressionWrapper
val summary = glm.summary

val rFeatures: Array[String] = if (glm.getFitIntercept) {
Array("(Intercept)") ++ features
Array("(Intercept)") ++ summary.featureNames
} else {
features
summary.featureNames
}

val rCoefficients: Array[Double] = if (summary.isNormalSolver) {
val rCoefficientStandardErrors = if (glm.getFitIntercept) {
Array(summary.coefficientStandardErrors.last) ++
summary.coefficientStandardErrors.dropRight(1)
} else {
summary.coefficientStandardErrors
}

val rTValues = if (glm.getFitIntercept) {
Array(summary.tValues.last) ++ summary.tValues.dropRight(1)
} else {
summary.tValues
}

val rPValues = if (glm.getFitIntercept) {
Array(summary.pValues.last) ++ summary.pValues.dropRight(1)
} else {
summary.pValues
}

if (glm.getFitIntercept) {
Array(glm.intercept) ++ glm.coefficients.toArray ++
rCoefficientStandardErrors ++ rTValues ++ rPValues
} else {
glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues
}
summary.coefficientsWithStatistics.map(_._2) ++
summary.coefficientsWithStatistics.map(_._3) ++
summary.coefficientsWithStatistics.map(_._4) ++
summary.coefficientsWithStatistics.map(_._5)
} else {
if (glm.getFitIntercept) {
Array(glm.intercept) ++ glm.coefficients.toArray