diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index ea52e9fe6c1c1..88256b810bf04 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -258,7 +258,11 @@ public Properties cryptoConf() { } /** - * The max number of chunks allowed to being transferred at the same time on shuffle service. + * The max number of chunks allowed to be transferred at the same time on shuffle service. + * Note that new incoming connections will be closed when the max number is hit. The client will + * retry according to the shuffle retry configs (see `spark.shuffle.io.maxRetries` and + * `spark.shuffle.io.retryWait`); if those limits are reached, the task will fail with a fetch + * failure. */ public long maxChunksBeingTransferred() { return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE); diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 1fb987a8a7aa7..1ed57116bc7bf 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -105,7 +105,7 @@ private class ExtendedChannelPromise extends DefaultChannelPromise { private List listeners = new ArrayList<>(); private boolean success; - public ExtendedChannelPromise(Channel channel) { + ExtendedChannelPromise(Channel channel) { super(channel); success = false; } @@ -127,7 +127,9 @@ public void finish(boolean success) { listeners.forEach(listener -> { try { listener.operationComplete(this); - } catch (Exception e) { } + } catch (Exception e) { + // do nothing + } }); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 338faaadb33d4..da6c55d9b8ac3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -120,14 +120,16 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.taskContext = taskContext; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; + this.fileBufferSizeBytes = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); this.peakMemoryUsedBytes = getMemoryUsage(); - this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); + this.diskWriteBufferSize = + (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE()); } /** diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index
3b6200e74f1e1..610ace30f8a62 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -258,6 +258,11 @@ private MapIterator(int numRecords, Location loc, boolean destructive) { this.destructive = destructive; if (destructive) { destructiveIterator = this; + // longArray will not be used anymore if destructive is true, release it now. + if (longArray != null) { + freeArray(longArray); + longArray = null; + } } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index a6e858ca72021..e2059cec132d2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.LinkedList; import java.util.Queue; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -48,8 +49,16 @@ public final class UnsafeExternalSorter extends MemoryConsumer { @Nullable private final PrefixComparator prefixComparator; + + /** + * {@link RecordComparator} may probably keep the reference to the records they compared last + * time, so we should not keep a {@link RecordComparator} instance inside + * {@link UnsafeExternalSorter}, because {@link UnsafeExternalSorter} is referenced by + * {@link TaskContext} and thus can not be garbage collected until the end of the task. + */ @Nullable - private final RecordComparator recordComparator; + private final Supplier recordComparatorSupplier; + private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final SerializerManager serializerManager; @@ -90,14 +99,14 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - serializerManager, taskContext, recordComparator, prefixComparator, initialSize, + serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. 
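The UnsafeExternalSorter hunks above replace a stored RecordComparator with a Supplier<RecordComparator>: the long-lived sorter, which stays reachable through TaskContext, now holds only a factory and builds a comparator right before it is needed, so the comparator (and any records it may still reference) can be garbage collected after each sort. A minimal Scala sketch of that pattern — names are hypothetical, not code from this patch:

import java.util.function.Supplier

object SupplierSketch {

  // A comparator that, like some RecordComparator implementations, keeps
  // references to the records it compared most recently.
  class CachingComparator {
    private var lastA: AnyRef = _
    private var lastB: AnyRef = _
    def compare(a: AnyRef, b: AnyRef): Int = {
      lastA = a // the cached records stay reachable as long as the comparator does
      lastB = b
      Integer.compare(a.hashCode(), b.hashCode())
    }
  }

  // Long-lived owner, analogous to UnsafeExternalSorter: it stores only the
  // factory, so no comparator (and no cached record) outlives a single sort.
  class LongLivedSorter(comparatorSupplier: Supplier[CachingComparator]) {
    def sorted(records: Seq[AnyRef]): Seq[AnyRef] = {
      val cmp = comparatorSupplier.get() // built on demand, collectable afterwards
      records.sortWith((a, b) => cmp.compare(a, b) < 0)
    }
  }

  def main(args: Array[String]): Unit = {
    // Callers hand over a factory (a SAM lambda) rather than an instance.
    val sorter = new LongLivedSorter(() => new CachingComparator)
    println(sorter.sorted(Seq("banana", "apple", "cherry")))
  }
}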
@@ -110,14 +119,14 @@ public static UnsafeExternalSorter create( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, + taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, numElementsForSpillThreshold, null, canUseRadixSort); } @@ -126,7 +135,7 @@ private UnsafeExternalSorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, @@ -138,15 +147,24 @@ private UnsafeExternalSorter( this.blockManager = blockManager; this.serializerManager = serializerManager; this.taskContext = taskContext; - this.recordComparator = recordComparator; + this.recordComparatorSupplier = recordComparatorSupplier; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; if (existingInMemorySorter == null) { + RecordComparator comparator = null; + if (recordComparatorSupplier != null) { + comparator = recordComparatorSupplier.get(); + } this.inMemSorter = new UnsafeInMemorySorter( - this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort); + this, + taskMemoryManager, + comparator, + prefixComparator, + initialSize, + canUseRadixSort); } else { this.inMemSorter = existingInMemorySorter; } @@ -451,14 +469,14 @@ public void merge(UnsafeExternalSorter other) throws IOException { * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { - assert(recordComparator != null); + assert(recordComparatorSupplier != null); if (spillWriters.isEmpty()) { assert(inMemSorter != null); readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); return readingIterator; } else { - final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); + final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger( + recordComparatorSupplier.get(), prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); } diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala index 3432700f11602..fe7438ac54f18 100644 --- a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -37,13 +37,7 @@ private[r] class JVMObjectTracker { /** * Returns the JVM object associated with the input key or None if not found. 
*/ - final def get(id: JVMObjectId): Option[Object] = this.synchronized { - if (objMap.containsKey(id)) { - Some(objMap.get(id)) - } else { - None - } - } + final def get(id: JVMObjectId): Option[Object] = Option(objMap.get(id)) /** * Returns the JVM object associated with the input key or throws an exception if not found. @@ -67,13 +61,7 @@ private[r] class JVMObjectTracker { /** * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. */ - final def remove(id: JVMObjectId): Option[Object] = this.synchronized { - if (objMap.containsKey(id)) { - Some(objMap.remove(id)) - } else { - None - } - } + final def remove(id: JVMObjectId): Option[Object] = Option(objMap.remove(id)) /** * Number of JVM objects being tracked. diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index fd1521193fdee..3721b98d68685 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -514,7 +514,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S outStream.println( s""" |Options: - | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. + | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local + | (Default: local[*]). | --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or | on one of the worker machines inside the cluster ("cluster") | (Default: client). diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2985c90119468..5435f59ea0d28 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -55,7 +55,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. - * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)] + * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)]) * through implicit. * * Internally, each RDD is characterized by five main properties: diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index cd5db1a70f722..5330a688e63e3 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -154,7 +154,7 @@ private UnsafeExternalSorter newSorter() throws IOException { blockManager, serializerManager, taskContext, - recordComparator, + () -> recordComparator, prefixComparator, /* initialSize */ 1024, pageSizeBytes, @@ -440,7 +440,7 @@ public void testPeakMemoryUsed() throws Exception { blockManager, serializerManager, taskContext, - recordComparator, + () -> recordComparator, prefixComparator, 1024, pageSizeBytes, diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 4bacb385184c6..28971b87f403f 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -17,10 +17,11 @@ # limitations under the License. 
# -# Utility for creating well-formed pull request merges and pushing them to Apache. -# usage: ./apache-pr-merge.py (see config env vars below) +# Utility for creating well-formed pull request merges and pushing them to Apache +# Spark. +# usage: ./merge_spark_pr.py (see config env vars below) # -# This utility assumes you already have local a Spark git folder and that you +# This utility assumes you already have a local Spark git folder and that you # have added remotes corresponding to both (i) the github apache Spark # mirror and (ii) the apache git repo. diff --git a/docs/configuration.md b/docs/configuration.md index f4b6f46db5b66..500f980455b0e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -635,7 +635,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.maxChunksBeingTransferred Long.MAX_VALUE - The max number of chunks allowed to being transferred at the same time on shuffle service. + The max number of chunks allowed to be transferred at the same time on shuffle service. + Note that new incoming connections will be closed when the max number is hit. The client will + retry according to the shuffle retry configs (see spark.shuffle.io.maxRetries and + spark.shuffle.io.retryWait); if those limits are reached, the task will fail with + a fetch failure. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index c0215c8fb62f6..26025984da64c 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -978,40 +978,40 @@ for details. Return a new RDD that contains the intersection of elements in the source dataset and the argument. - distinct([numTasks])) + distinct([numPartitions]) Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numPartitions]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. - You can pass an optional numTasks argument to set a different number of tasks. + You can pass an optional numPartitions argument to set a different number of tasks. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numPartitions]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numPartitions]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numPartitions]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numPartitions]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numPartitions]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index d4ddcb16bdd0e..44ae52e81cd64 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -175,7 +175,7 @@ an input DStream using data received by the instance of custom receiver, as show {% highlight scala %} // Assuming ssc is the StreamingContext val customReceiverStream = ssc.receiverStream(new CustomReceiver(host, port)) -val words = lines.flatMap(_.split(" ")) +val words = customReceiverStream.flatMap(_.split(" ")) ... {% endhighlight %} @@ -187,7 +187,7 @@ The full source code is in the example [CustomReceiver.scala]({{site.SPARK_GITHU {% highlight java %} // Assuming ssc is the JavaStreamingContext JavaDStream customReceiverStream = ssc.receiverStream(new JavaCustomReceiver(host, port)); -JavaDStream words = lines.flatMap(s -> ...); +JavaDStream words = customReceiverStream.flatMap(s -> ...); ... 
{% endhighlight %} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 7cbcccf2720a3..05b8c3ab5456e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -34,6 +34,7 @@ import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -53,7 +54,8 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { +private[ml] trait OneVsRestParams extends PredictorParams + with ClassifierTypeTrait with HasWeightCol { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -294,6 +296,18 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** + * Sets the value of param [[weightCol]]. + * + * This is ignored if weight is not supported by [[classifier]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. + * + * @group setParam + */ + @Since("2.3.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) @@ -317,7 +331,20 @@ final class OneVsRest @Since("1.4.0") ( val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) instr.logNumClasses(numClasses) - val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) + val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && { + getClassifier match { + case _: HasWeightCol => true + case c => + logWarning(s"weightCol is ignored, as it is not supported by $c now.") + false + } + } + + val multiclassLabeled = if (weightColIsUsed) { + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + } else { + dataset.select($(labelCol), $(featuresCol)) + } // persist if underlying dataset is not persistent. 
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -337,7 +364,13 @@ final class OneVsRest @Since("1.4.0") ( paramMap.put(classifier.labelCol -> labelColName) paramMap.put(classifier.featuresCol -> getFeaturesCol) paramMap.put(classifier.predictionCol -> getPredictionCol) - classifier.fit(trainingDataset, paramMap) + if (weightColIsUsed) { + val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol] + paramMap.put(classifier_.weightCol -> getWeightCol) + classifier_.fit(trainingDataset, paramMap) + } else { + classifier.fit(trainingDataset, paramMap) + } }.toArray[ClassificationModel[_, _]] instr.logNumFeatures(models.head.numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index ee1fc9b14ceaa..176a6cf852914 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -83,11 +83,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setStringIndexerOrderType(stringIndexerOrderType) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get labels and feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) + // assemble and fit the pipeline val glr = new GeneralizedLinearRegression() .setFamily(family) @@ -113,37 +109,16 @@ private[r] object GeneralizedLinearRegressionWrapper val summary = glm.summary val rFeatures: Array[String] = if (glm.getFitIntercept) { - Array("(Intercept)") ++ features + Array("(Intercept)") ++ summary.featureNames } else { - features + summary.featureNames } val rCoefficients: Array[Double] = if (summary.isNormalSolver) { - val rCoefficientStandardErrors = if (glm.getFitIntercept) { - Array(summary.coefficientStandardErrors.last) ++ - summary.coefficientStandardErrors.dropRight(1) - } else { - summary.coefficientStandardErrors - } - - val rTValues = if (glm.getFitIntercept) { - Array(summary.tValues.last) ++ summary.tValues.dropRight(1) - } else { - summary.tValues - } - - val rPValues = if (glm.getFitIntercept) { - Array(summary.pValues.last) ++ summary.pValues.dropRight(1) - } else { - summary.pValues - } - - if (glm.getFitIntercept) { - Array(glm.intercept) ++ glm.coefficients.toArray ++ - rCoefficientStandardErrors ++ rTValues ++ rPValues - } else { - glm.coefficients.toArray ++ rCoefficientStandardErrors ++ rTValues ++ rPValues - } + summary.coefficientsWithStatistics.map(_._2) ++ + summary.coefficientsWithStatistics.map(_._3) ++ + summary.coefficientsWithStatistics.map(_._4) ++ + summary.coefficientsWithStatistics.map(_._5) } else { if (glm.getFitIntercept) { Array(glm.intercept) ++ glm.coefficients.toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 815607f0a76d2..917a4d238d467 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -20,12 +20,14 @@ package org.apache.spark.ml.regression import java.util.Locale import breeze.stats.{distributions => dist} +import 
org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.{Instance, OffsetInstance} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors} import org.apache.spark.ml.optim._ @@ -37,7 +39,6 @@ import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} - /** * Params for Generalized Linear Regression. */ @@ -141,6 +142,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for offset column name. If this is not set or empty, we treat all instance offsets * as 0.0. The feature specified as offset has a constant coefficient of 1.0. + * * @group param */ @Since("2.3.0") @@ -1204,6 +1206,21 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.2.0") lazy val numInstances: Long = predictions.count() + + /** + * Name of features. If the name cannot be retrieved from attributes, + * set default names to feature column name with numbered suffix "_0", "_1", and so on. + */ + private[ml] lazy val featureNames: Array[String] = { + val featureAttrs = AttributeGroup.fromStructField( + dataset.schema(model.getFeaturesCol)).attributes + if (featureAttrs.isDefined) { + featureAttrs.get.map(_.name.get) + } else { + Array.tabulate[String](origModel.numFeatures)((x: Int) => model.getFeaturesCol + "_" + x) + } + } + /** The numeric rank of the fitted linear model. */ @Since("2.0.0") lazy val rank: Long = if (model.getFitIntercept) { @@ -1458,4 +1475,96 @@ class GeneralizedLinearRegressionTrainingSummary private[regression] ( "No p-value available for this GeneralizedLinearRegressionModel") } } + + /** + * Coefficients with statistics: feature name, coefficients, standard error, tValue and pValue. + */ + private[ml] lazy val coefficientsWithStatistics: Array[ + (String, Double, Double, Double, Double)] = { + var featureNamesLocal = featureNames + var coefficientsArray = model.coefficients.toArray + var index = Array.range(0, coefficientsArray.length) + if (model.getFitIntercept) { + featureNamesLocal = featureNamesLocal :+ "(Intercept)" + coefficientsArray = coefficientsArray :+ model.intercept + // Reorder so that intercept comes first + index = (coefficientsArray.length - 1) +: index + } + index.map { i => + (featureNamesLocal(i), coefficientsArray(i), coefficientStandardErrors(i), + tValues(i), pValues(i)) + } + } + + override def toString: String = { + if (isNormalSolver) { + + def round(x: Double): String = { + BigDecimal(x).setScale(4, BigDecimal.RoundingMode.HALF_UP).toString + } + + val colNames = Array("Feature", "Estimate", "Std Error", "T Value", "P Value") + + val data = coefficientsWithStatistics.map { row => + val strRow = row.productIterator.map { cell => + val str = cell match { + case s: String => s + case n: Double => round(n) + } + // Truncate if length > 20 + if (str.length > 20) { + str.substring(0, 17) + "..." 
+ } else { + str + } + } + strRow.toArray + } + + // Compute the width of each column + val colWidths = colNames.map(_.length) + data.foreach { strRow => + strRow.zipWithIndex.foreach { case (cell: String, i: Int) => + colWidths(i) = math.max(colWidths(i), cell.length) + } + } + + val sb = new StringBuilder + + // Output coefficients with statistics + sb.append("Coefficients:\n") + colNames.zipWithIndex.map { case (colName: String, i: Int) => + StringUtils.leftPad(colName, colWidths(i)) + }.addString(sb, "", " ", "\n") + + data.foreach { case strRow: Array[String] => + strRow.zipWithIndex.map { case (cell: String, i: Int) => + StringUtils.leftPad(cell.toString, colWidths(i)) + }.addString(sb, "", " ", "\n") + } + + sb.append("\n") + sb.append(s"(Dispersion parameter for ${family.name} family taken to be " + + s"${round(dispersion)})") + + sb.append("\n") + val nd = s"Null deviance: ${round(nullDeviance)} on $degreesOfFreedom degrees of freedom" + val rd = s"Residual deviance: ${round(deviance)} on $residualDegreeOfFreedom degrees of " + + "freedom" + val l = math.max(nd.length, rd.length) + sb.append(StringUtils.leftPad(nd, l)) + sb.append("\n") + sb.append(StringUtils.leftPad(rd, l)) + + if (family.name != "tweedie") { + sb.append("\n") + sb.append(s"AIC: " + round(aic)) + } + + sb.toString() + } else { + throw new UnsupportedOperationException( + "No summary available for this GeneralizedLinearRegressionModel") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index c02e38ad64e3e..17f82827b74e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -156,6 +156,16 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } + test("SPARK-21306: OneVsRest should support setWeightCol") { + val dataset2 = dataset.withColumn("weight", lit(1)) + // classifier inherits hasWeightCol + val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression()) + assert(ova.fit(dataset2) !== null) + // classifier doesn't inherit hasWeightCol + val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier()) + assert(ova2.fit(dataset2) !== null) + } + test("OneVsRest.copy and OneVsRestModel.copy") { val lr = new LogisticRegression() .setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index a47bd17f47bb1..df7dee869d058 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, OffsetInstance} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{LabeledPoint, RFormula} import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -1524,6 +1524,87 @@ 
class GeneralizedLinearRegressionSuite .fit(datasetGaussianIdentity.as[LabeledPoint]) } + test("glm summary: feature name") { + // dataset1 with no attribute + val dataset1 = Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0)), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)), + Instance(2.0, 5.0, Vectors.dense(2.0, 3.0)) + ).toDF() + + // dataset2 with attribute + val datasetTmp = Seq( + (2.0, 1.0, 0.0, 5.0), + (8.0, 2.0, 1.0, 7.0), + (3.0, 3.0, 2.0, 11.0), + (9.0, 4.0, 3.0, 13.0), + (2.0, 5.0, 2.0, 3.0) + ).toDF("y", "w", "x1", "x2") + val formula = new RFormula().setFormula("y ~ x1 + x2") + val dataset2 = formula.fit(datasetTmp).transform(datasetTmp) + + val expectedFeature = Seq(Array("features_0", "features_1"), Array("x1", "x2")) + + var idx = 0 + for (dataset <- Seq(dataset1, dataset2)) { + val model = new GeneralizedLinearRegression().fit(dataset) + model.summary.featureNames.zip(expectedFeature(idx)) + .foreach{ x => assert(x._1 === x._2) } + idx += 1 + } + } + + test("glm summary: coefficient with statistics") { + /* + R code: + + A <- matrix(c(0, 1, 2, 3, 2, 5, 7, 11, 13, 3), 5, 2) + b <- c(2, 8, 3, 9, 2) + df <- as.data.frame(cbind(A, b)) + model <- glm(formula = "b ~ .", data = df) + summary(model) + + Coefficients: + Estimate Std. Error t value Pr(>|t|) + (Intercept) 0.7903 4.0129 0.197 0.862 + V1 0.2258 2.1153 0.107 0.925 + V2 0.4677 0.5815 0.804 0.506 + */ + val dataset = Seq( + Instance(2.0, 1.0, Vectors.dense(0.0, 5.0)), + Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(9.0, 4.0, Vectors.dense(3.0, 13.0)), + Instance(2.0, 5.0, Vectors.dense(2.0, 3.0)) + ).toDF() + + val expectedFeature = Seq(Array("features_0", "features_1"), + Array("(Intercept)", "features_0", "features_1")) + val expectedEstimate = Seq(Vectors.dense(0.2884, 0.538), + Vectors.dense(0.7903, 0.2258, 0.4677)) + val expectedStdError = Seq(Vectors.dense(1.724, 0.3787), + Vectors.dense(4.0129, 2.1153, 0.5815)) + + var idx = 0 + for (fitIntercept <- Seq(false, true)) { + val trainer = new GeneralizedLinearRegression() + .setFamily("gaussian") + .setFitIntercept(fitIntercept) + val model = trainer.fit(dataset) + val coefficientsWithStatistics = model.summary.coefficientsWithStatistics + + coefficientsWithStatistics.map(_._1).zip(expectedFeature(idx)).foreach { x => + assert(x._1 === x._2, "Feature name mismatch in coefficientsWithStatistics") } + assert(Vectors.dense(coefficientsWithStatistics.map(_._2)) ~= expectedEstimate(idx) + absTol 1E-3, "Coefficients mismatch in coefficientsWithStatistics") + assert(Vectors.dense(coefficientsWithStatistics.map(_._3)) ~= expectedStdError(idx) + absTol 1E-3, "Standard error mismatch in coefficientsWithStatistics") + idx += 1 + } + } + test("generalized linear regression: regularization parameter") { /* R code: diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 82207f664480a..4af6f71e19257 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1447,7 +1447,7 @@ def weights(self): return self._call_java("weights") -class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol): """ Parameters for OneVsRest and OneVsRestModel. 
""" @@ -1517,10 +1517,10 @@ class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable): @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - classifier=None): + classifier=None, weightCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - classifier=None) + classifier=None, weightCol=None) """ super(OneVsRest, self).__init__() kwargs = self._input_kwargs @@ -1528,9 +1528,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only @since("2.0.0") - def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): + def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, + classifier=None, weightCol=None): """ - setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): + setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \ + classifier=None, weightCol=None): Sets params for OneVsRest. """ kwargs = self._input_kwargs @@ -1546,7 +1548,18 @@ def _fit(self, dataset): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - multiclassLabeled = dataset.select(labelCol, featuresCol) + weightCol = None + if (self.isDefined(self.weightCol) and self.getWeightCol()): + if isinstance(classifier, HasWeightCol): + weightCol = self.getWeightCol() + else: + warnings.warn("weightCol is ignored, " + "as it is not supported by {} now.".format(classifier)) + + if weightCol: + multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol) + else: + multiclassLabeled = dataset.select(labelCol, featuresCol) # persist if underlying dataset is not persistent. handlePersistence = \ @@ -1562,6 +1575,8 @@ def trainSingleClass(index): paramMap = dict([(classifier.labelCol, binaryLabelCol), (classifier.featuresCol, featuresCol), (classifier.predictionCol, predictionCol)]) + if weightCol: + paramMap[classifier.weightCol] = weightCol return classifier.fit(trainingDataset, paramMap) # TODO: Parallel training for all classes. 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6c71e69c9b5f9..a9ca346fa5d83 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1394,6 +1394,20 @@ def test_output_columns(self): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) + def test_support_for_weightCol(self): + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), + (1.0, Vectors.sparse(2, [], []), 1.0), + (2.0, Vectors.dense(0.5, 0.5), 1.0)], + ["label", "features", "weight"]) + # classifier inherits hasWeightCol + lr = LogisticRegression(maxIter=5, regParam=0.01) + ovr = OneVsRest(classifier=lr, weightCol="weight") + self.assertIsNotNone(ovr.fit(df)) + # classifier doesn't inherit hasWeightCol + dt = DecisionTreeClassifier() + ovr2 = OneVsRest(classifier=dt, weightCol="weight") + self.assertIsNotNone(ovr2.fit(df)) + class HashingTFTest(SparkSessionTestCase): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1c1a0cad49625..cfd9c558ff67e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1241,26 +1241,29 @@ def test_struct_type(self): struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) self.assertEqual(struct1, struct2) struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) self.assertNotEqual(struct1, struct2) struct1 = (StructType().add(StructField("f1", StringType(), True)) .add(StructField("f2", StringType(), True, None))) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) self.assertEqual(struct1, struct2) struct1 = (StructType().add(StructField("f1", StringType(), True)) .add(StructField("f2", StringType(), True, None))) struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) self.assertNotEqual(struct1, struct2) # Catch exception raised during improper construction - with self.assertRaises(ValueError): - struct1 = StructType().add("name") + self.assertRaises(ValueError, lambda: StructType().add("name")) struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) for field in struct1: @@ -1273,12 +1276,9 @@ def test_struct_type(self): self.assertIs(struct1["f1"], struct1.fields[0]) self.assertIs(struct1[0], struct1.fields[0]) self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) - with self.assertRaises(KeyError): - not_a_field = struct1["f9"] - with self.assertRaises(IndexError): - not_a_field = struct1[9] - with self.assertRaises(TypeError): - not_a_field = struct1[9.9] + self.assertRaises(KeyError, lambda: struct1["f9"]) + self.assertRaises(IndexError, lambda: struct1[9]) + self.assertRaises(TypeError, lambda: struct1[9.9]) def test_parse_datatype_string(self): from pyspark.sql.types import _all_atomic_types, _parse_datatype_string @@ -3018,8 +3018,8 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)]) - df = 
self.spark.createDataFrame([([1, 2, 3],)], schema=schema) + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c376805c32738..ecb8eb9a2f2fa 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -446,9 +446,12 @@ class StructType(DataType): This is the data type representing a :class:`Row`. - Iterating a :class:`StructType` will iterate its :class:`StructField`s. + Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. + .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead + to get a list of field names. + >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) @@ -563,6 +566,16 @@ def jsonValue(self): def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) + def fieldNames(self): + """ + Returns all field names in a list. + + >>> struct = StructType([StructField("f1", StringType(), True)]) + >>> struct.fieldNames() + ['f1'] + """ + return list(self.names) + def needConversion(self): # We need convert Row()/namedtuple into tuple() return True diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index fc925022b2718..ca6a3ef3ebbb5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -90,6 +90,9 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + // A flag to check whether user has initialized spark context + @volatile private var registered = false + private val userClassLoader = { val classpath = Client.getUserClasspath(sparkConf) val urls = classpath.map { entry => @@ -319,7 +322,7 @@ private[spark] class ApplicationMaster( */ final def unregister(status: FinalApplicationStatus, diagnostics: String = null): Unit = { synchronized { - if (!unregistered) { + if (registered && !unregistered) { logInfo(s"Unregistering ApplicationMaster with $status" + Option(diagnostics).map(msg => s" (diag message: $msg)").getOrElse("")) unregistered = true @@ -332,10 +335,15 @@ private[spark] class ApplicationMaster( synchronized { if (!finished) { val inShutdown = ShutdownHookManager.inShutdown() - logInfo(s"Final app status: $status, exitCode: $code" + + if (registered) { + exitCode = code + finalStatus = status + } else { + finalStatus = FinalApplicationStatus.FAILED + exitCode = ApplicationMaster.EXIT_SC_NOT_INITED + } + logInfo(s"Final app status: $finalStatus, exitCode: $exitCode" + Option(msg).map(msg => s", (reason: $msg)").getOrElse("")) - exitCode = code - finalStatus = status finalMsg = msg finished = true if (!inShutdown && Thread.currentThread() != reporterThread && reporterThread != null) { @@ -439,12 +447,11 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.port"), isClusterMode = true) registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr) + registered = true } else { // Sanity check; should never 
happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. - if (!finished) { - throw new IllegalStateException("SparkContext is null but app is still running!") - } + throw new IllegalStateException("User did not initialize spark context!") } userClassThread.join() } catch { diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ef9f88a9026c9..4534b7dcf6399 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -474,7 +474,7 @@ identifierComment relationPrimary : tableIdentifier sample? tableAlias #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' queryNoWith ')' sample? tableAlias #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 | functionTable #tableValuedFunction diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java index a88a315bf479f..df52f9c2d5496 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java @@ -62,7 +62,7 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen, keyRowId = numRows; keyRow.pointTo(base, recordOffset, klen); - valueRow.pointTo(base, recordOffset + klen, vlen + 4); + valueRow.pointTo(base, recordOffset + klen, vlen); numRows++; return valueRow; } @@ -95,7 +95,7 @@ protected UnsafeRow getValueFromKey(int rowId) { getKeyRow(rowId); } assert(rowId >= 0); - valueRow.pointTo(base, keyRow.getBaseOffset() + klen, vlen + 4); + valueRow.pointTo(base, keyRow.getBaseOffset() + klen, vlen); return valueRow; } @@ -131,7 +131,7 @@ public boolean next() { } key.pointTo(base, offsetInPage, klen); - value.pointTo(base, offsetInPage + klen, vlen + 4); + value.pointTo(base, offsetInPage + klen, vlen); offsetInPage += recordLength; recordsInPage -= 1; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 56994fafe064b..ec947d7580282 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -167,6 +167,7 @@ public UnsafeRow() {} */ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; + assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8"; this.baseObject = baseObject; this.baseOffset = baseOffset; this.sizeInBytes = sizeInBytes; @@ -183,6 +184,7 @@ public void pointTo(byte[] buf, int sizeInBytes) { } public void setTotalSize(int sizeInBytes) { + assert sizeInBytes % 8 == 0 : "sizeInBytes (" + sizeInBytes + ") should be a multiple of 8"; this.sizeInBytes = sizeInBytes; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java index ea4f984be24e5..905e6820ce6e2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/VariableLengthRowBasedKeyValueBatch.java @@ -65,7 +65,7 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen, keyRowId = numRows; keyRow.pointTo(base, recordOffset + 8, klen); - valueRow.pointTo(base, recordOffset + 8 + klen, vlen + 4); + valueRow.pointTo(base, recordOffset + 8 + klen, vlen); numRows++; return valueRow; } @@ -102,7 +102,7 @@ public UnsafeRow getValueFromKey(int rowId) { long offset = keyRow.getBaseOffset(); int klen = keyRow.getSizeInBytes(); int vlen = Platform.getInt(base, offset - 8) - klen - 4; - valueRow.pointTo(base, offset + klen, vlen + 4); + valueRow.pointTo(base, offset + klen, vlen); return valueRow; } @@ -146,7 +146,7 @@ public boolean next() { currentvlen = totalLength - currentklen; key.pointTo(base, offsetInPage + 8, currentklen); - value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen + 4); + value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen); offsetInPage += 8 + totalLength + 8; recordsInPage -= 1; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index aadfcaa56cc2d..12a123ee0bcff 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -84,7 +84,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - new RowComparator(ordering, schema.length()), + () -> new RowComparator(ordering, schema.length()), prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -195,12 +195,10 @@ public Iterator sort(Iterator inputIterator) throws IOExce private static final class RowComparator extends RecordComparator { private final Ordering ordering; - private final int numFields; private final UnsafeRow row1; private final UnsafeRow row2; RowComparator(Ordering ordering, int numFields) { - this.numFields = numFields; this.row1 = new UnsafeRow(numFields); this.row2 = new UnsafeRow(numFields); this.ordering = ordering; @@ -208,9 +206,10 @@ private static final class RowComparator extends RecordComparator { @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - // TODO: Why are the sizes -1? - row1.pointTo(baseObj1, baseOff1, -1); - row2.pointTo(baseObj2, baseOff2, -1); + // Note that since ordering doesn't need the total length of the record, we just pass 0 + // into the row. 
+ row1.pointTo(baseObj1, baseOff1, 0); + row2.pointTo(baseObj2, baseOff2, 0); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 913d846a8c23b..a6d297cfd6538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -141,6 +141,7 @@ class Analyzer( ResolveFunctions :: ResolveAliases :: ResolveSubquery :: + ResolveSubqueryColumnAliases :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: @@ -1323,6 +1324,30 @@ class Analyzer( } } + /** + * Replaces unresolved column aliases for a subquery with projections. + */ + object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => + // Resolves output attributes if a query has alias names in its subquery: + // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) + val outputAttrs = child.output + // Checks if the number of the aliases equals to the number of output columns + // in the subquery. + if (columnNames.size != outputAttrs.size) { + u.failAnalysis("Number of column aliases does not match number of columns. " + + s"Number of column aliases: ${columnNames.size}; " + + s"number of columns: ${outputAttrs.size}.") + } + val aliases = outputAttrs.zip(columnNames).map { case (attr, aliasName) => + Alias(attr, aliasName)() + } + Project(aliases, child) + } + } + /** * Turns projections that contain aggregate expressions into aggregations. */ @@ -2234,7 +2259,9 @@ object EliminateUnions extends Rule[LogicalPlan] { /** * Cleans up unnecessary Aliases inside the plan. Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or - * Window(window expressions). + * Window(window expressions). Notice that if an expression has other expression parameters which + * are not in its `children`, e.g. `RuntimeReplaceable`, the transformation for Aliases in this + * rule can't work for those parameters. 
*/ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 85c52792ef659..e235689cc36ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -108,11 +108,9 @@ trait CheckAnalysis extends PredicateHelper { case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") - case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, - SpecifiedWindowFrame(frame, - FrameBoundary(l), - FrameBoundary(h)))) - if order.isEmpty || frame != RowFrame || l != h => + case w @ WindowExpression(_: OffsetWindowFunction, + WindowSpecDefinition(_, order, frame: SpecifiedWindowFrame)) + if order.isEmpty || !frame.isOffset => failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") @@ -121,15 +119,10 @@ trait CheckAnalysis extends PredicateHelper { // function. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } - // Make sure the window specification is valid. - s.validate match { - case Some(m) => - failAnalysis(s"Window specification $s is not valid because $m") - case None => w - } case s: SubqueryExpression => checkSubqueryExpression(operator, s) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a78e1c98e89de..25af014f67fe9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,6 +59,7 @@ object TypeCoercion { PropagateTypes :: ImplicitTypeCasts :: DateTimeOperations :: + WindowFrameCoercion :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -805,4 +806,26 @@ object TypeCoercion { Option(ret) } } + + /** + * Cast WindowFrame boundaries to the type they operate upon. 
+ */ + object WindowFrameCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) + if order.resolved => + s.copy(frameSpecification = SpecifiedWindowFrame( + RangeFrame, + createBoundaryCast(lower, order.dataType), + createBoundaryCast(upper, order.dataType))) + } + + private def createBoundaryCast(boundary: Expression, dt: DataType): Expression = { + boundary match { + case e: SpecialFrameBoundary => e + case e: Expression if e.dataType != dt && Cast.canCast(e.dataType, dt) => Cast(e, dt) + case _ => boundary + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index fb322697c7c68..b7a704dc8453a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types.{DataType, Metadata, StructType} @@ -422,6 +422,27 @@ case class UnresolvedAlias( override lazy val resolved = false } +/** + * Aliased column names, resolved by position, for a subquery. We can assign alias names to the + * output columns of a subquery: + * {{{ + * // Assign alias names for output columns + * SELECT col1, col2 FROM testData AS t(col1, col2); + * }}} + * + * @param outputColumnNames the alias names assigned, by position, to the output columns of `child`. + * @param child the logical plan of this subquery. + */ +case class UnresolvedSubqueryColumnAliases( + outputColumnNames: Seq[String], + child: LogicalPlan) + extends UnaryNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + /** * Holds the deserializer expression and the attributes that are available during the resolution * for it. Deserializer expression is a special kind of expression that is not always resolved by diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b847ef7bfaa97..74c4cddf2b47e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -241,6 +241,10 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override def nullable: Boolean = child.nullable override def foldable: Boolean = child.foldable override def dataType: DataType = child.dataType + // As this expression gets replaced at optimization with its `child` expression, + // two `RuntimeReplaceable` are considered to be semantically equal if their `child` expressions + // are semantically equal.
+ override lazy val canonicalized: Expression = child.canonicalized } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 4c8b177237d23..1a48995358af7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -74,6 +74,13 @@ package object expressions { def initialize(partitionIndex: Int): Unit = {} } + /** + * An identity projection. This returns the input row. + */ + object IdentityProjection extends Projection { + override def apply(row: InternalRow): InternalRow = row + } + /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each * column of the new row. If the schema of the input row is specified, then the given expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 88afd43223d1d..a829dccfd3e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} @@ -43,34 +42,7 @@ case class WindowSpecDefinition( orderSpec: Seq[SortOrder], frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable { - def validate: Option[String] = frameSpecification match { - case UnspecifiedFrame => - Some("Found a UnspecifiedFrame. It should be converted to a SpecifiedWindowFrame " + - "during analysis. Please file a bug report.") - case frame: SpecifiedWindowFrame => frame.validate.orElse { - def checkValueBasedBoundaryForRangeFrame(): Option[String] = { - if (orderSpec.length > 1) { - // It is not allowed to have a value-based PRECEDING and FOLLOWING - // as the boundary of a Range Window Frame. 
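Editor's note: the `canonicalized` override above works because a `RuntimeReplaceable` node is rewritten to its child before execution, so only the child matters for semantic equality. A standalone toy sketch of that rule (the classes below are illustrative stand-ins, not Catalyst types):

```scala
object CanonicalizedSketch {
  // `Wrapper` plays the role of a RuntimeReplaceable that the optimizer rewrites to `child`,
  // so canonicalization simply delegates to the child.
  sealed trait Expr { def canonicalized: Expr = this }
  case class Col(name: String) extends Expr
  case class Wrapper(surfaceForm: String, child: Expr) extends Expr {
    override def canonicalized: Expr = child.canonicalized
  }

  def semanticEquals(a: Expr, b: Expr): Boolean = a.canonicalized == b.canonicalized

  def main(args: Array[String]): Unit = {
    // Two different surface forms over the same child compare equal once canonicalized.
    println(semanticEquals(Wrapper("nvl", Col("a")), Wrapper("ifnull", Col("a"))))  // true
  }
}
```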
- Some("This Range Window Frame only accepts at most one ORDER BY expression.") - } else if (orderSpec.nonEmpty && !orderSpec.head.dataType.isInstanceOf[NumericType]) { - Some("The data type of the expression in the ORDER BY clause should be a numeric type.") - } else { - None - } - } - - (frame.frameType, frame.frameStart, frame.frameEnd) match { - case (RangeFrame, vp: ValuePreceding, _) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, vf: ValueFollowing, _) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, _, vp: ValuePreceding) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, _, vf: ValueFollowing) => checkValueBasedBoundaryForRangeFrame() - case (_, _, _) => None - } - } - } - - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec :+ frameSpecification override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && @@ -78,23 +50,46 @@ case class WindowSpecDefinition( override def nullable: Boolean = true override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException + override def dataType: DataType = throw new UnsupportedOperationException("dataType") - override def sql: String = { - val partition = if (partitionSpec.isEmpty) { - "" - } else { - "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + " " + override def checkInputDataTypes(): TypeCheckResult = { + frameSpecification match { + case UnspecifiedFrame => + TypeCheckFailure( + "Cannot use an UnspecifiedFrame. This should have been converted during analysis. " + + "Please file a bug report.") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && !f.isUnbounded && + orderSpec.isEmpty => + TypeCheckFailure( + "A range window frame cannot be used in an unordered window specification.") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && + orderSpec.size > 1 => + TypeCheckFailure( + s"A range window frame with value boundaries cannot be used in a window specification " + + s"with multiple order by expressions: ${orderSpec.mkString(",")}") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && + !isValidFrameType(f.valueBoundary.head.dataType) => + TypeCheckFailure( + s"The data type '${orderSpec.head.dataType}' used in the order specification does " + + s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " + + "range frame.") + case _ => TypeCheckSuccess } + } - val order = if (orderSpec.isEmpty) { - "" - } else { - "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + " " + override def sql: String = { + def toSql(exprs: Seq[Expression], prefix: String): Seq[String] = { + Seq(exprs).filter(_.nonEmpty).map(_.map(_.sql).mkString(prefix, ", ", "")) } - s"($partition$order${frameSpecification.toString})" + val elements = + toSql(partitionSpec, "PARTITION BY ") ++ + toSql(orderSpec, "ORDER BY ") ++ + Seq(frameSpecification.sql) + elements.mkString("(", " ", ")") } + + private def isValidFrameType(ft: DataType): Boolean = orderSpec.head.dataType == ft } /** @@ -106,22 +101,26 @@ case class WindowSpecReference(name: String) extends WindowSpec /** * The trait used to represent the type of a Window Frame. */ -sealed trait FrameType +sealed trait FrameType { + def inputType: AbstractDataType + def sql: String +} /** - * RowFrame treats rows in a partition individually. 
When a [[ValuePreceding]] - * or a [[ValueFollowing]] is used as its [[FrameBoundary]], the value is considered - * as a physical offset. + * RowFrame treats rows in a partition individually. Values used in a row frame are considered + * to be physical offsets. * For example, `ROW BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a 3-row frame, * from the row that precedes the current row to the row that follows the current row. */ -case object RowFrame extends FrameType +case object RowFrame extends FrameType { + override def inputType: AbstractDataType = IntegerType + override def sql: String = "ROWS" +} /** - * RangeFrame treats rows in a partition as groups of peers. - * All rows having the same `ORDER BY` ordering are considered as peers. - * When a [[ValuePreceding]] or a [[ValueFollowing]] is used as its [[FrameBoundary]], - * the value is considered as a logical offset. + * RangeFrame treats rows in a partition as groups of peers. All rows having the same `ORDER BY` + * ordering are considered as peers. Values used in a range frame are considered to be logical + * offsets. * For example, assuming the value of the current row's `ORDER BY` expression `expr` is `v`, * `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a frame containing rows whose values * `expr` are in the range of [v-1, v+1]. @@ -129,138 +128,144 @@ case object RowFrame extends FrameType * If `ORDER BY` clause is not defined, all rows in the partition are considered as peers * of the current row. */ -case object RangeFrame extends FrameType - -/** - * The trait used to represent the type of a Window Frame Boundary. - */ -sealed trait FrameBoundary { - def notFollows(other: FrameBoundary): Boolean +case object RangeFrame extends FrameType { + override def inputType: AbstractDataType = NumericType + override def sql: String = "RANGE" } /** - * Extractor for making working with frame boundaries easier. + * The trait used to represent special boundaries used in a window frame. */ -object FrameBoundary { - def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None - } +sealed trait SpecialFrameBoundary extends Expression with Unevaluable { + override def children: Seq[Expression] = Nil + override def dataType: DataType = NullType + override def foldable: Boolean = false + override def nullable: Boolean = false } -/** UNBOUNDED PRECEDING boundary. */ -case object UnboundedPreceding extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => true - case vp: ValuePreceding => true - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = "UNBOUNDED PRECEDING" +/** UNBOUNDED boundary. */ +case object UnboundedPreceding extends SpecialFrameBoundary { + override def sql: String = "UNBOUNDED PRECEDING" } -/** PRECEDING boundary. 
*/ -case class ValuePreceding(value: Int) extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case ValuePreceding(anotherValue) => value >= anotherValue - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = s"$value PRECEDING" +case object UnboundedFollowing extends SpecialFrameBoundary { + override def sql: String = "UNBOUNDED FOLLOWING" } /** CURRENT ROW boundary. */ -case object CurrentRow extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = "CURRENT ROW" -} - -/** FOLLOWING boundary. */ -case class ValueFollowing(value: Int) extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => false - case ValueFollowing(anotherValue) => value <= anotherValue - case UnboundedFollowing => true - } - - override def toString: String = s"$value FOLLOWING" -} - -/** UNBOUNDED FOLLOWING boundary. */ -case object UnboundedFollowing extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => false - case vf: ValueFollowing => false - case UnboundedFollowing => true - } - - override def toString: String = "UNBOUNDED FOLLOWING" +case object CurrentRow extends SpecialFrameBoundary { + override def sql: String = "CURRENT ROW" } /** * Represents a window frame. */ -sealed trait WindowFrame +sealed trait WindowFrame extends Expression with Unevaluable { + override def children: Seq[Expression] = Nil + override def dataType: DataType = throw new UnsupportedOperationException("dataType") + override def foldable: Boolean = false + override def nullable: Boolean = false +} /** Used as a placeholder when a frame specification is not defined. */ case object UnspecifiedFrame extends WindowFrame -/** A specified Window Frame. */ +/** + * A specified Window Frame. The val lower/uppper can be either a foldable [[Expression]] or a + * [[SpecialFrameBoundary]]. + */ case class SpecifiedWindowFrame( frameType: FrameType, - frameStart: FrameBoundary, - frameEnd: FrameBoundary) extends WindowFrame { - - /** If this WindowFrame is valid or not. */ - def validate: Option[String] = (frameType, frameStart, frameEnd) match { - case (_, UnboundedFollowing, _) => - Some(s"$UnboundedFollowing is not allowed as the start of a Window Frame.") - case (_, _, UnboundedPreceding) => - Some(s"$UnboundedPreceding is not allowed as the end of a Window Frame.") - // case (RowFrame, start, end) => ??? RowFrame specific rule - // case (RangeFrame, start, end) => ??? RangeFrame specific rule - case (_, start, end) => - if (start.notFollows(end)) { - None - } else { - val reason = - s"The end of this Window Frame $end is smaller than the start of " + - s"this Window Frame $start." - Some(reason) - } + lower: Expression, + upper: Expression) + extends WindowFrame { + + override def children: Seq[Expression] = lower :: upper :: Nil + + lazy val valueBoundary: Seq[Expression] = + children.filterNot(_.isInstanceOf[SpecialFrameBoundary]) + + override def checkInputDataTypes(): TypeCheckResult = { + // Check lower value. 
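Editor's note: with the boundaries now modeled as ordinary expressions, a frame such as `ROWS BETWEEN 10 PRECEDING AND CURRENT ROW` is built from a negated literal plus the `CurrentRow` special boundary. A short usage sketch against a build that includes this patch (the wrapper object is only for illustration):

```scala
import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Literal, RowFrame, SpecifiedWindowFrame, UnaryMinus, UnboundedPreceding}

object FrameSqlSketch {
  def main(args: Array[String]): Unit = {
    // "n PRECEDING" is just a negated literal now; special boundaries stay as case objects.
    val f1 = SpecifiedWindowFrame(RowFrame, UnaryMinus(Literal(10)), CurrentRow)
    println(f1.sql)  // ROWS BETWEEN 10 PRECEDING AND CURRENT ROW

    val f2 = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, Literal(5))
    println(f2.sql)  // ROWS BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING
  }
}
```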
+ val lowerCheck = checkBoundary(lower, "lower") + if (lowerCheck.isFailure) { + return lowerCheck + } + + // Check upper value. + val upperCheck = checkBoundary(upper, "upper") + if (upperCheck.isFailure) { + return upperCheck + } + + // Check combination (of expressions). + (lower, upper) match { + case (l: Expression, u: Expression) if !isValidFrameBoundary(l, u) => + TypeCheckFailure(s"Window frame upper bound '$upper' does not follow the lower bound " + + s"'$lower'.") + case (l: SpecialFrameBoundary, _) => TypeCheckSuccess + case (_, u: SpecialFrameBoundary) => TypeCheckSuccess + case (l: Expression, u: Expression) if l.dataType != u.dataType => + TypeCheckFailure( + s"Window frame bounds '$lower' and '$upper' do not have the same data type: " + + s"'${l.dataType.catalogString}' <> '${u.dataType.catalogString}'") + case (l: Expression, u: Expression) if isGreaterThan(l, u) => + TypeCheckFailure( + "The lower bound of a window frame must be less than or equal to the upper bound") + case _ => TypeCheckSuccess + } + } + + override def sql: String = { + val lowerSql = boundarySql(lower) + val upperSql = boundarySql(upper) + s"${frameType.sql} BETWEEN $lowerSql AND $upperSql" } - override def toString: String = frameType match { - case RowFrame => s"ROWS BETWEEN $frameStart AND $frameEnd" - case RangeFrame => s"RANGE BETWEEN $frameStart AND $frameEnd" + def isUnbounded: Boolean = lower == UnboundedPreceding && upper == UnboundedFollowing + + def isValueBound: Boolean = valueBoundary.nonEmpty + + def isOffset: Boolean = (lower, upper) match { + case (l: Expression, u: Expression) => frameType == RowFrame && l == u + case _ => false + } + + private def boundarySql(expr: Expression): String = expr match { + case e: SpecialFrameBoundary => e.sql + case UnaryMinus(n) => n.sql + " PRECEDING" + case e: Expression => e.sql + " FOLLOWING" + } + + private def isGreaterThan(l: Expression, r: Expression): Boolean = { + GreaterThan(l, r).eval().asInstanceOf[Boolean] + } + + private def checkBoundary(b: Expression, location: String): TypeCheckResult = b match { + case _: SpecialFrameBoundary => TypeCheckSuccess + case e: Expression if !e.foldable => + TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") + case e: Expression if !frameType.inputType.acceptsType(e.dataType) => + TypeCheckFailure( + s"The data type of the $location bound '${e.dataType}' does not match " + + s"the expected data type '${frameType.inputType}'.") + case _ => TypeCheckSuccess + } + + private def isValidFrameBoundary(l: Expression, u: Expression): Boolean = { + (l, u) match { + case (UnboundedFollowing, _) => false + case (_, UnboundedPreceding) => false + case _ => true + } + } } object SpecifiedWindowFrame { /** - * * @param hasOrderSpecification If the window spec has order by expressions. * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return + * @return the default window frame. */ def defaultWindowFrame( hasOrderSpecification: Boolean, @@ -351,20 +356,25 @@ abstract class OffsetWindowFunction override def nullable: Boolean = default == null || default.nullable || input.nullable - override lazy val frame = { - // This will be triggered by the Analyzer.
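Editor's note: the new `checkInputDataTypes` can be exercised directly to see the per-bound and combination checks in action. A small sketch, again assuming a build that contains this patch (the wrapper object is illustrative only):

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, RowFrame, SpecifiedWindowFrame, UnaryMinus}

object FrameCheckSketch {
  def main(args: Array[String]): Unit = {
    // ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING: both bounds are foldable integers, so this passes.
    println(SpecifiedWindowFrame(RowFrame, UnaryMinus(Literal(2)), Literal(3)).checkInputDataTypes())

    // ROWS BETWEEN 5 FOLLOWING AND 5 PRECEDING: the lower bound evaluates to 5 and the upper to
    // -5, so the combination check reports that the lower bound exceeds the upper bound.
    println(SpecifiedWindowFrame(RowFrame, Literal(5), UnaryMinus(Literal(5))).checkInputDataTypes())
  }
}
```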
- val offsetValue = offset.eval() match { - case o: Int => o - case x => throw new AnalysisException( - s"Offset expression must be a foldable integer expression: $x") - } + override lazy val frame: WindowFrame = { val boundary = direction match { - case Ascending => ValueFollowing(offsetValue) - case Descending => ValuePreceding(offsetValue) + case Ascending => offset + case Descending => UnaryMinus(offset) } SpecifiedWindowFrame(RowFrame, boundary, boundary) } + override def checkInputDataTypes(): TypeCheckResult = { + val check = super.checkInputDataTypes() + if (check.isFailure) { + check + } else if (!offset.foldable) { + TypeCheckFailure(s"Offset expression '$offset' must be a literal.") + } else { + TypeCheckSuccess + } + } + override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 45c1d3d430e0d..07578261781b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -750,20 +750,28 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different - * hooks. + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} */ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - val alias = if (ctx.strictIdentifier == null) { + val alias = if (ctx.tableAlias.strictIdentifier == null) { // For un-aliased subqueries, use a default alias name that is not likely to conflict with // normal subquery names, so that parent operators can only access the columns in subquery by // unqualified names. Users can still use this special qualifier to access columns if they // know it, but that's not recommended. "__auto_generated_subquery_name" } else { - ctx.strictIdentifier.getText + ctx.tableAlias.strictIdentifier.getText + } + val subquery = SubqueryAlias(alias, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) + if (ctx.tableAlias.identifierList != null) { + val columnAliases = visitIdentifierList(ctx.tableAlias.identifierList) + UnresolvedSubqueryColumnAliases(columnAliases, subquery) + } else { + subquery } - - SubqueryAlias(alias, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) } /** @@ -1179,32 +1187,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value - * Preceding/Following boundaries. These expressions must be constant (foldable) and return an - * integer value. + * Create or resolve a frame boundary expressions. */ - override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { - // We currently only allow foldable integers. 
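Editor's note: `lead` and `lag` are the `OffsetWindowFunction`s affected by the new `frame` definition above; nothing changes for users of the DataFrame API. A minimal usage example (local master, view and column names are arbitrary):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, lead}

object LeadLagExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("lead-lag").getOrCreate()
    import spark.implicits._

    val df = Seq(1, 2, 3, 4).toDF("x")
    val w = Window.orderBy("x")

    // lead/lag internally get a ROWS frame whose lower and upper bounds are both the
    // (foldable) offset expression; the offset must be a literal after this change.
    df.select($"x", lead($"x", 1).over(w).as("next"), lag($"x", 1).over(w).as("prev")).show()

    spark.stop()
  }
}
```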
- def value: Int = { + override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) { + def value: Expression = { val e = expression(ctx.expression) - validate(e.resolved && e.foldable && e.dataType == IntegerType, - "Frame bound value must be a constant integer.", - ctx) - e.eval().asInstanceOf[Int] + validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx) + e } - // Create the FrameBoundary ctx.boundType.getType match { case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => UnboundedPreceding case SqlBaseParser.PRECEDING => - ValuePreceding(value) + UnaryMinus(value) case SqlBaseParser.CURRENT => CurrentRow case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => UnboundedFollowing case SqlBaseParser.FOLLOWING => - ValueFollowing(value) + value } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 7375a0bcbae75..b6889f21cc6ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -688,8 +688,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case id: FunctionIdentifier => true case spec: BucketSpec => true case catalog: CatalogTable => true - case boundary: FrameBoundary => true - case frame: WindowFrame => true case partition: Partitioning => true case resource: FunctionResource => true case broadcast: BroadcastMode => true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7311dc3899e53..4e0613619add6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -190,7 +190,7 @@ class AnalysisErrorSuite extends AnalysisTest { WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, - SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + SpecifiedWindowFrame(RangeFrame, Literal(1), Literal(2)))).as('window)), "window frame" :: "must match the required frame" :: Nil) errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index be26b1b26f175..9bcf4773fa903 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -470,4 +470,24 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { Seq("Number of column aliases does not match number of columns. 
Table name: TaBlE3; " + "number of column aliases: 5; number of columns: 4.")) } + + test("SPARK-20962 Support subquery column aliases in FROM clause") { + def tableColumnsWithAliases(outputNames: Seq[String]): LogicalPlan = { + UnresolvedSubqueryColumnAliases( + outputNames, + SubqueryAlias( + "t", + UnresolvedRelation(TableIdentifier("TaBlE3"))) + ).select(star()) + } + assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) + assertAnalysisError( + tableColumnsWithAliases("col1" :: Nil), + Seq("Number of column aliases does not match number of columns. " + + "Number of column aliases: 1; number of columns: 4.")) + assertAnalysisError( + tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), + Seq("Number of column aliases does not match number of columns. " + + "Number of column aliases: 5; number of columns: 4.")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index b3994ab0828ad..d62e3b6dfe34f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1109,6 +1109,42 @@ class TypeCoercionSuite extends AnalysisTest { EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) } + + test("cast WindowFrame boundaries to the type they operate upon") { + // Can cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(3), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Cast(3, LongType), Literal(2147483648L))) + ) + // Cannot cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))) + ) + // Should not cast SpecialFrameBoundary. 
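Editor's note: the SPARK-20962 tests above correspond to the following end-user syntax. A small example of assigning positional column aliases to a FROM-clause subquery (the temp view name, aliases, and expressions are arbitrary):

```scala
import org.apache.spark.sql.SparkSession

object SubqueryColumnAliasExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("subquery-aliases").getOrCreate()

    spark.range(3).createOrReplaceTempView("t")

    // With this patch, the subquery's output columns are renamed by position to (x, y).
    // Supplying the wrong number of aliases fails analysis with the error tested above.
    spark.sql("SELECT x, y FROM (SELECT id AS a, id * 2 AS b FROM t) AS s(x, y)").show()

    spark.stop()
  }
}
```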
+ ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 45f9f72dccc45..76c79b3d0760c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -267,16 +267,17 @@ class ExpressionParserSuite extends PlanTest { // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", ValuePreceding(10), CurrentRow), - ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("10 preceding", -Literal(10), CurrentRow), + ("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), ("unbounded preceding", UnboundedPreceding, CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), ("between unbounded preceding and unbounded following", UnboundedPreceding, UnboundedFollowing), - ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), - ("between current row and 5 following", CurrentRow, ValueFollowing(5)), - ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between current row and 5 following", CurrentRow, Literal(5)), + ("between 10 preceding and 5 following", -Literal(10), Literal(5)) ) frameTypes.foreach { case (frameTypeSql, frameType) => @@ -288,13 +289,9 @@ class ExpressionParserSuite extends PlanTest { } } - // We cannot use non integer constants. - intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", - "Frame bound value must be a constant integer.") - // We cannot use an arbitrary expression. 
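Editor's note: the parser tests above show that a frame bound is now kept as an unevaluated expression (`10 preceding` becomes `-Literal(10)`, `3 + 1 following` becomes `Add(Literal(3), Literal(1))`). To poke at this interactively, something along these lines should work on a patched build; the exact printed form is not guaranteed:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object FrameBoundParsingSketch {
  def main(args: Array[String]): Unit = {
    // Parsing keeps the bounds as expressions; analysis later checks that they are literals
    // of the type expected by the frame.
    val e = CatalystSqlParser.parseExpression(
      "count(*) OVER (ORDER BY a ROWS BETWEEN 10 PRECEDING AND CURRENT ROW)")
    println(e)
  }
}
```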
intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", - "Frame bound value must be a constant integer.") + "Frame bound value must be a literal.") } test("row constructor") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 6dad097041a15..c7f39ae18162e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -243,7 +243,7 @@ class PlanParserSuite extends AnalysisTest { val sql = "select * from t" val plan = table("t").select(star()) val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), - SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + SpecifiedWindowFrame(RowFrame, -Literal(1), Literal(1))) // Test window resolution. val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) @@ -495,6 +495,17 @@ class PlanParserSuite extends AnalysisTest { .select(star())) } + test("SPARK-20962 Support subquery column aliases in FROM clause") { + assertEqual( + "SELECT * FROM (SELECT a AS x, b AS y FROM t) t(col1, col2)", + UnresolvedSubqueryColumnAliases( + Seq("col1", "col2"), + SubqueryAlias( + "t", + UnresolvedRelation(TableIdentifier("t")).select('a.as("x"), 'b.as("y"))) + ).select(star())) + } + test("inline table") { assertEqual("values 1, 2, 3, 4", UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 819078218c546..4fc947a88f6b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -436,21 +436,22 @@ class TreeNodeSuite extends SparkFunSuite { "bucketColumnNames" -> "[bucket]", "sortColumnNames" -> "[sort]")) - // Converts FrameBoundary to JSON - assertJSON( - ValueFollowing(3), - JObject( - "product-class" -> classOf[ValueFollowing].getName, - "value" -> 3)) - // Converts WindowFrame to JSON assertJSON( - SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow), - JObject( - "product-class" -> classOf[SpecifiedWindowFrame].getName, - "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), - "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)), - "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName)))) + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow), + List( + JObject( + "class" -> classOf[SpecifiedWindowFrame].getName, + "num-children" -> 2, + 
"frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), + "lower" -> 0, + "upper" -> 1), + JObject( + "class" -> UnboundedPreceding.getClass.getName, + "num-children" -> 0), + JObject( + "class" -> CurrentRow.getClass.getName, + "num-children" -> 0))) // Converts Partitioning to JSON assertJSON( diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index ee5bcfd02c79e..6aa52f1aae048 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -19,6 +19,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; @@ -76,7 +77,8 @@ public UnsafeKVExternalSorter( prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema); PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema); BaseOrdering ordering = GenerateOrdering.create(keySchema); - KVComparator recordComparator = new KVComparator(ordering, keySchema.length()); + Supplier comparatorSupplier = + () -> new KVComparator(ordering, keySchema.length()); boolean canUseRadixSort = keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)); @@ -88,7 +90,7 @@ public UnsafeKVExternalSorter( blockManager, serializerManager, taskContext, - recordComparator, + comparatorSupplier, prefixComparator, SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize", UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -104,7 +106,11 @@ public UnsafeKVExternalSorter( // as the underlying array for in-memory sorter (it's always large enough). // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(), + null, + taskMemoryManager, + comparatorSupplier.get(), + prefixComparator, + map.getArray(), canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory @@ -137,7 +143,7 @@ public UnsafeKVExternalSorter( blockManager, serializerManager, taskContext, - new KVComparator(ordering, keySchema.length()), + comparatorSupplier, prefixComparator, SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize", UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -227,10 +233,8 @@ private static final class KVComparator extends RecordComparator { private final BaseOrdering ordering; private final UnsafeRow row1; private final UnsafeRow row2; - private final int numKeyFields; KVComparator(BaseOrdering ordering, int numKeyFields) { - this.numKeyFields = numKeyFields; this.row1 = new UnsafeRow(numKeyFields); this.row2 = new UnsafeRow(numKeyFields); this.ordering = ordering; @@ -238,10 +242,10 @@ private static final class KVComparator extends RecordComparator { @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - // Note that since ordering doesn't need the total length of the record, we just pass -1 + // Note that since ordering doesn't need the total length of the record, we just pass 0 // into the row. 
- row1.pointTo(baseObj1, baseOff1 + 4, -1); - row2.pointTo(baseObj2, baseOff2 + 4, -1); + row1.pointTo(baseObj1, baseOff1 + 4, 0); + row2.pointTo(baseObj2, baseOff2 + 4, 0); return ordering.compare(row1, row2); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 31dea6ad31b12..59d66c599c518 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -322,7 +322,7 @@ public ArrowColumnVector(ValueVector vector) { anyNullsSet = numNulls > 0; } - private static abstract class ArrowVectorAccessor { + private abstract static class ArrowVectorAccessor { private final ValueVector vector; private final ValueVector.Accessor nulls; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 255c4064eb574..0fcda46c9b3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -499,7 +499,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    *
  • `compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive - * shorten names(none, `snappy`, `gzip`, and `lzo`). This will override + * shorten names(`none`, `snappy`, `gzip`, and `lzo`). This will override * `spark.sql.parquet.compression.codec`.
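Editor's note: the option documented above is set per write and overrides the session-wide codec. A short usage sketch (the output path is a placeholder):

```scala
import org.apache.spark.sql.SparkSession

object ParquetCompressionExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("parquet-compression").getOrCreate()

    // Accepted case-insensitive short names include none, snappy, gzip, and lzo; this per-write
    // option takes precedence over spark.sql.parquet.compression.codec.
    spark.range(10).write.option("compression", "snappy").parquet("/tmp/parquet-compression-demo")

    spark.stop()
  }
}
```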
  • *
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 71ab0ddf2d6f4..aa968d8b3c34d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils +import org.apache.spark.TaskContext import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ @@ -1107,7 +1108,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): Dataset[T] = { - sort((sortCol +: sortCols).map(apply) : _*) + sort((sortCol +: sortCols).map(Column(_)) : _*) } /** @@ -3090,7 +3091,8 @@ class Dataset[T] private[sql]( val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch queryExecution.toRdd.mapPartitionsInternal { iter => - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch) + val context = TaskContext.get() + ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index c913efe52a41c..240f38f5bfeb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -20,18 +20,13 @@ package org.apache.spark.sql.execution.arrow import java.io.ByteArrayOutputStream import java.nio.channels.Channels -import scala.collection.JavaConverters._ - -import io.netty.buffer.ArrowBuf -import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.file._ -import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.schema.ArrowRecordBatch import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -55,19 +50,6 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se def asPythonSerializable: Array[Byte] = payload } -private[sql] object ArrowPayload { - - /** - * Create an ArrowPayload from an ArrowRecordBatch and Spark schema. 
- */ - def apply( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): ArrowPayload = { - new ArrowPayload(ArrowConverters.batchToByteArray(batch, schema, allocator)) - } -} - private[sql] object ArrowConverters { /** @@ -77,95 +59,55 @@ private[sql] object ArrowConverters { private[sql] def toPayloadIterator( rowIter: Iterator[InternalRow], schema: StructType, - maxRecordsPerBatch: Int): Iterator[ArrowPayload] = { - new Iterator[ArrowPayload] { - private val _allocator = new RootAllocator(Long.MaxValue) - private var _nextPayload = if (rowIter.nonEmpty) convert() else null + maxRecordsPerBatch: Int, + context: TaskContext): Iterator[ArrowPayload] = { - override def hasNext: Boolean = _nextPayload != null - - override def next(): ArrowPayload = { - val obj = _nextPayload - if (hasNext) { - if (rowIter.hasNext) { - _nextPayload = convert() - } else { - _allocator.close() - _nextPayload = null - } - } - obj - } - - private def convert(): ArrowPayload = { - val batch = internalRowIterToArrowBatch(rowIter, schema, _allocator, maxRecordsPerBatch) - ArrowPayload(batch, schema, _allocator) - } - } - } + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) - /** - * Iterate over InternalRows and write to an ArrowRecordBatch, stopping when rowIter is consumed - * or the number of records in the batch equals maxRecordsInBatch. If maxRecordsPerBatch is 0, - * then rowIter will be fully consumed. - */ - private def internalRowIterToArrowBatch( - rowIter: Iterator[InternalRow], - schema: StructType, - allocator: BufferAllocator, - maxRecordsPerBatch: Int = 0): ArrowRecordBatch = { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) - val columnWriters = schema.fields.zipWithIndex.map { case (field, ordinal) => - ColumnWriter(field.dataType, ordinal, allocator).init() - } + var closed = false - val writerLength = columnWriters.length - var recordsInBatch = 0 - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || recordsInBatch < maxRecordsPerBatch)) { - val row = rowIter.next() - var i = 0 - while (i < writerLength) { - columnWriters(i).write(row) - i += 1 + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() } - recordsInBatch += 1 } - val (fieldNodes, bufferArrays) = columnWriters.map(_.finish()).unzip - val buffers = bufferArrays.flatten - - val rowLength = if (fieldNodes.nonEmpty) fieldNodes.head.getLength else 0 - val recordBatch = new ArrowRecordBatch(rowLength, - fieldNodes.toList.asJava, buffers.toList.asJava) + new Iterator[ArrowPayload] { - buffers.foreach(_.release()) - recordBatch - } + override def hasNext: Boolean = rowIter.hasNext || { + root.close() + allocator.close() + closed = true + false + } - /** - * Convert an ArrowRecordBatch to a byte array and close batch to release resources. Once closed, - * the batch can no longer be used. 
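Editor's note: the rewritten `toPayloadIterator` above cuts a new Arrow batch once `maxRecordsPerBatch` rows have been written, or when the input ends, and frees the allocator via a task-completion listener. The chunking policy in isolation, as a plain-iterator sketch with no Arrow involved (names are illustrative):

```scala
object BatchingSketch {
  // Consume the rows lazily and emit chunks of at most maxRecordsPerBatch elements;
  // a limit of 0 (or negative) means "no limit", matching the Spark config's semantics.
  def chunk[T](rows: Iterator[T], maxRecordsPerBatch: Int): Iterator[Seq[T]] =
    new Iterator[Seq[T]] {
      override def hasNext: Boolean = rows.hasNext
      override def next(): Seq[T] = {
        val buf = Vector.newBuilder[T]
        var count = 0
        while (rows.hasNext && (maxRecordsPerBatch <= 0 || count < maxRecordsPerBatch)) {
          buf += rows.next()
          count += 1
        }
        buf.result()
      }
    }

  def main(args: Array[String]): Unit = {
    println(chunk(Iterator(1, 2, 3, 4, 5), maxRecordsPerBatch = 2).toList)
    // List(Vector(1, 2), Vector(3, 4), Vector(5))
  }
}
```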
- */ - private[arrow] def batchToByteArray( - batch: ArrowRecordBatch, - schema: StructType, - allocator: BufferAllocator): Array[Byte] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val out = new ByteArrayOutputStream() - val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + override def next(): ArrowPayload = { + val out = new ByteArrayOutputStream() + val writer = new ArrowFileWriter(root, null, Channels.newChannel(out)) + + Utils.tryWithSafeFinally { + var rowCount = 0 + while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + val row = rowIter.next() + arrowWriter.write(row) + rowCount += 1 + } + arrowWriter.finish() + writer.writeBatch() + } { + arrowWriter.reset() + writer.close() + } - // Write a batch to byte stream, ensure the batch, allocator and writer are closed - Utils.tryWithSafeFinally { - val loader = new VectorLoader(root) - loader.load(batch) - writer.writeBatch() // writeBatch can throw IOException - } { - batch.close() - root.close() - writer.close() + new ArrowPayload(out.toByteArray) + } } - out.toByteArray } /** @@ -188,214 +130,3 @@ private[sql] object ArrowConverters { } } } - -/** - * Interface for writing InternalRows to Arrow Buffers. - */ -private[arrow] trait ColumnWriter { - def init(): this.type - def write(row: InternalRow): Unit - - /** - * Clear the column writer and return the ArrowFieldNode and ArrowBuf. - * This should be called only once after all the data is written. - */ - def finish(): (ArrowFieldNode, Array[ArrowBuf]) -} - -/** - * Base class for flat arrow column writer, i.e., column without children. - */ -private[arrow] abstract class PrimitiveColumnWriter(val ordinal: Int) - extends ColumnWriter { - - def getFieldType(dtype: ArrowType): FieldType = FieldType.nullable(dtype) - - def valueVector: BaseDataValueVector - def valueMutator: BaseMutator - - def setNull(): Unit - def setValue(row: InternalRow): Unit - - protected var count = 0 - protected var nullCount = 0 - - override def init(): this.type = { - valueVector.allocateNew() - this - } - - override def write(row: InternalRow): Unit = { - if (row.isNullAt(ordinal)) { - setNull() - nullCount += 1 - } else { - setValue(row) - } - count += 1 - } - - override def finish(): (ArrowFieldNode, Array[ArrowBuf]) = { - valueMutator.setValueCount(count) - val fieldNode = new ArrowFieldNode(count, nullCount) - val valueBuffers = valueVector.getBuffers(true) - (fieldNode, valueBuffers) - } -} - -private[arrow] class BooleanColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBitVector - = new NullableBitVector("BooleanValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBitVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, if (row.getBoolean(ordinal)) 1 else 0 ) -} - -private[arrow] class ShortColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableSmallIntVector - = new NullableSmallIntVector("ShortValue", getFieldType(dtype: ArrowType), allocator) - override val valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: 
InternalRow): Unit - = valueMutator.setSafe(count, row.getShort(ordinal)) -} - -private[arrow] class IntegerColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableIntVector - = new NullableIntVector("IntValue", getFieldType(dtype), allocator) - override val valueMutator: NullableIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getInt(ordinal)) -} - -private[arrow] class LongColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableBigIntVector - = new NullableBigIntVector("LongValue", getFieldType(dtype), allocator) - override val valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getLong(ordinal)) -} - -private[arrow] class FloatColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat4Vector - = new NullableFloat4Vector("FloatValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getFloat(ordinal)) -} - -private[arrow] class DoubleColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableFloat8Vector - = new NullableFloat8Vector("DoubleValue", getFieldType(dtype), allocator) - override val valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getDouble(ordinal)) -} - -private[arrow] class ByteColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableUInt1Vector - = new NullableUInt1Vector("ByteValue", getFieldType(dtype), allocator) - override val valueMutator: NullableUInt1Vector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit - = valueMutator.setSafe(count, row.getByte(ordinal)) -} - -private[arrow] class UTF8StringColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarCharVector - = new NullableVarCharVector("UTF8StringValue", getFieldType(dtype), allocator) - override val valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val str = row.getUTF8String(ordinal) - valueMutator.setSafe(count, str.getByteBuffer, 0, str.numBytes) - } -} - -private[arrow] class BinaryColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableVarBinaryVector - = new NullableVarBinaryVector("BinaryValue", getFieldType(dtype), allocator) - override 
val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - val bytes = row.getBinary(ordinal) - valueMutator.setSafe(count, bytes, 0, bytes.length) - } -} - -private[arrow] class DateColumnWriter(dtype: ArrowType, ordinal: Int, allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableDateDayVector - = new NullableDateDayVector("DateValue", getFieldType(dtype), allocator) - override val valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getInt(ordinal)) - } -} - -private[arrow] class TimeStampColumnWriter( - dtype: ArrowType, - ordinal: Int, - allocator: BufferAllocator) - extends PrimitiveColumnWriter(ordinal) { - override val valueVector: NullableTimeStampMicroVector - = new NullableTimeStampMicroVector("TimeStampValue", getFieldType(dtype), allocator) - override val valueMutator: NullableTimeStampMicroVector#Mutator = valueVector.getMutator - - override def setNull(): Unit = valueMutator.setNull(count) - override def setValue(row: InternalRow): Unit = { - valueMutator.setSafe(count, row.getLong(ordinal)) - } -} - -private[arrow] object ColumnWriter { - - /** - * Create an Arrow ColumnWriter given the type and ordinal of row. - */ - def apply(dataType: DataType, ordinal: Int, allocator: BufferAllocator): ColumnWriter = { - val dtype = ArrowUtils.toArrowType(dataType) - dataType match { - case BooleanType => new BooleanColumnWriter(dtype, ordinal, allocator) - case ShortType => new ShortColumnWriter(dtype, ordinal, allocator) - case IntegerType => new IntegerColumnWriter(dtype, ordinal, allocator) - case LongType => new LongColumnWriter(dtype, ordinal, allocator) - case FloatType => new FloatColumnWriter(dtype, ordinal, allocator) - case DoubleType => new DoubleColumnWriter(dtype, ordinal, allocator) - case ByteType => new ByteColumnWriter(dtype, ordinal, allocator) - case StringType => new UTF8StringColumnWriter(dtype, ordinal, allocator) - case BinaryType => new BinaryColumnWriter(dtype, ordinal, allocator) - case DateType => new DateColumnWriter(dtype, ordinal, allocator) - case TimestampType => new TimeStampColumnWriter(dtype, ordinal, allocator) - case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala new file mode 100644 index 0000000000000..11ba04d2ce9a7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex._ +import org.apache.arrow.vector.util.DecimalUtility + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types._ + +object ArrowWriter { + + def create(schema: StructType): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) + create(root) + } + + def create(root: VectorSchemaRoot): ArrowWriter = { + val children = root.getFieldVectors().asScala.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + new ArrowWriter(root, children.toArray) + } + + private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + val field = vector.getField() + (ArrowUtils.fromArrowField(field), vector) match { + case (BooleanType, vector: NullableBitVector) => new BooleanWriter(vector) + case (ByteType, vector: NullableTinyIntVector) => new ByteWriter(vector) + case (ShortType, vector: NullableSmallIntVector) => new ShortWriter(vector) + case (IntegerType, vector: NullableIntVector) => new IntegerWriter(vector) + case (LongType, vector: NullableBigIntVector) => new LongWriter(vector) + case (FloatType, vector: NullableFloat4Vector) => new FloatWriter(vector) + case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) + case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) + case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (ArrayType(_, _), vector: ListVector) => + val elementVector = createFieldWriter(vector.getDataVector()) + new ArrayWriter(vector, elementVector) + case (StructType(_), vector: NullableMapVector) => + val children = (0 until vector.size()).map { ordinal => + createFieldWriter(vector.getChildByOrdinal(ordinal)) + } + new StructWriter(vector, children.toArray) + case (dt, _) => + throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") + } + } +} + +class ArrowWriter( + val root: VectorSchemaRoot, + fields: Array[ArrowFieldWriter]) { + + def schema: StructType = StructType(fields.map { f => + StructField(f.name, f.dataType, f.nullable) + }) + + private var count: Int = 0 + + def write(row: InternalRow): Unit = { + var i = 0 + while (i < fields.size) { + fields(i).write(row, i) + i += 1 + } + count += 1 + } + + def finish(): Unit = { + root.setRowCount(count) + fields.foreach(_.finish()) + } + + def reset(): Unit = { + root.setRowCount(0) + count = 0 + fields.foreach(_.reset()) + } +} + +private[arrow] abstract class ArrowFieldWriter { + + def valueVector: ValueVector + def valueMutator: ValueVector.Mutator + + def name: String = valueVector.getField().getName() + def dataType: DataType = ArrowUtils.fromArrowField(valueVector.getField()) + def nullable: Boolean = valueVector.getField().isNullable() + + def setNull(): Unit + def setValue(input: SpecializedGetters, ordinal: Int): Unit + + 
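Editor's note: the `ArrowFieldWriter` contract being introduced here is write-then-`finish()`-then-`reset()`, so the same vectors can be reused for the next batch. A stand-in class (not Spark or Arrow code) showing that lifecycle:

```scala
object FieldWriterLifecycleSketch {
  // Stand-in for an ArrowFieldWriter: append values, seal the batch, then reset for reuse.
  final class IntFieldWriter {
    private val values = scala.collection.mutable.ArrayBuffer.empty[Option[Int]]
    private var count = 0

    def write(v: Option[Int]): Unit = { values += v; count += 1 }
    def finish(): Int = count                   // analogous to setValueCount(count) on the mutator
    def reset(): Unit = { values.clear(); count = 0 }
  }

  def main(args: Array[String]): Unit = {
    val w = new IntFieldWriter
    Seq(Some(1), None, Some(3)).foreach(w.write)
    println(w.finish())  // 3 rows in this batch
    w.reset()            // ready for the next batch, mirroring ArrowWriter.reset()
  }
}
```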
private[arrow] var count: Int = 0 + + def write(input: SpecializedGetters, ordinal: Int): Unit = { + if (input.isNullAt(ordinal)) { + setNull() + } else { + setValue(input, ordinal) + } + count += 1 + } + + def finish(): Unit = { + valueMutator.setValueCount(count) + } + + def reset(): Unit = { + valueMutator.reset() + count = 0 + } +} + +private[arrow] class BooleanWriter(val valueVector: NullableBitVector) extends ArrowFieldWriter { + + override def valueMutator: NullableBitVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0) + } +} + +private[arrow] class ByteWriter(val valueVector: NullableTinyIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTinyIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getByte(ordinal)) + } +} + +private[arrow] class ShortWriter(val valueVector: NullableSmallIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableSmallIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getShort(ordinal)) + } +} + +private[arrow] class IntegerWriter(val valueVector: NullableIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class LongWriter(val valueVector: NullableBigIntVector) extends ArrowFieldWriter { + + override def valueMutator: NullableBigIntVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } +} + +private[arrow] class FloatWriter(val valueVector: NullableFloat4Vector) extends ArrowFieldWriter { + + override def valueMutator: NullableFloat4Vector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getFloat(ordinal)) + } +} + +private[arrow] class DoubleWriter(val valueVector: NullableFloat8Vector) extends ArrowFieldWriter { + + override def valueMutator: NullableFloat8Vector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getDouble(ordinal)) + } +} + +private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extends ArrowFieldWriter { + + override def valueMutator: NullableVarCharVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val utf8 = input.getUTF8String(ordinal) + // todo: for off-heap 
UTF8String, how to pass in to arrow without copy? + valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes()) + } +} + +private[arrow] class BinaryWriter( + val valueVector: NullableVarBinaryVector) extends ArrowFieldWriter { + + override def valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val bytes = input.getBinary(ordinal) + valueMutator.setSafe(count, bytes, 0, bytes.length) + } +} + +private[arrow] class ArrayWriter( + val valueVector: ListVector, + val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { + + override def valueMutator: ListVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val array = input.getArray(ordinal) + var i = 0 + valueMutator.startNewValue(count) + while (i < array.numElements()) { + elementWriter.write(array, i) + i += 1 + } + valueMutator.endValue(count, array.numElements()) + } + + override def finish(): Unit = { + super.finish() + elementWriter.finish() + } + + override def reset(): Unit = { + super.reset() + elementWriter.reset() + } +} + +private[arrow] class StructWriter( + val valueVector: NullableMapVector, + children: Array[ArrowFieldWriter]) extends ArrowFieldWriter { + + override def valueMutator: NullableMapVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + var i = 0 + while (i < children.length) { + children(i).setNull() + children(i).count += 1 + i += 1 + } + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + val struct = input.getStruct(ordinal, children.length) + var i = 0 + while (i < struct.numFields) { + children(i).write(struct, i) + i += 1 + } + valueMutator.setIndexDefined(count) + } + + override def finish(): Unit = { + super.finish() + children.foreach(_.finish()) + } + + override def reset(): Unit = { + super.reset() + children.foreach(_.reset()) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index fa4c99c01916f..e0c2e942072c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -369,7 +369,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val valueRowBuffer = new Array[Byte](valueSize) ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) val valueRow = new UnsafeRow(valueSchema.fields.length) - valueRow.pointTo(valueRowBuffer, valueSize) + // If valueSize in existing file is not multiple of 8, floor it to multiple of 8. 
+ // This is a workaround for the following: + // Prior to Spark 2.3, we mistakenly appended 4 bytes to the value row in + // `RowBasedKeyValueBatch`, which then got persisted into the checkpoint data + valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) map.put(keyRow, valueRow) } } @@ -433,7 +437,11 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val valueRowBuffer = new Array[Byte](valueSize) ByteStreams.readFully(input, valueRowBuffer, 0, valueSize) val valueRow = new UnsafeRow(valueSchema.fields.length) - valueRow.pointTo(valueRowBuffer, valueSize) + // If valueSize in existing file is not multiple of 8, floor it to multiple of 8. + // This is a workaround for the following: + // Prior to Spark 2.3, we mistakenly appended 4 bytes to the value row in + // `RowBasedKeyValueBatch`, which then got persisted into the checkpoint data + valueRow.pointTo(valueRowBuffer, (valueSize / 8) * 8) map.put(keyRow, valueRow) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 1820cb0ef540b..0766e37826cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.IntegerType /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -109,46 +108,50 @@ case class WindowExec( * * This method uses Code Generation. It can only be used on the executor side. * - * @param frameType to evaluate. This can either be Row or Range based. - * @param offset with respect to the row. + * @param frame the frame type to evaluate. This can either be a Row or Range frame. + * @param bound the boundary of the frame, with respect to the current row. + * @return a bound ordering object. */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { - frameType match { - case RangeFrame => - val (exprs, current, bound) = if (offset == 0) { - // Use the entire order expression when the offset is 0. - val exprs = orderSpec.map(_.child) - val buildProjection = () => newMutableProjection(exprs, child.output) - (orderSpec, buildProjection(), buildProjection()) - } else if (orderSpec.size == 1) { - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output) - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => -offset - case Ascending => offset - } - // Create the projection which returns the current 'value' modified by adding the offset.
- val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output) - (sortExpr :: Nil, current, bound) - } else { - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") + private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RangeFrame, CurrentRow) => + val ordering = newOrdering(orderSpec, child.output) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(boundOffset, expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output) + // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. This is done manually because we want to use // Code Generation (if it is enabled). - val sortExprs = exprs.zipWithIndex.map { case (e, i) => - SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) - } - val ordering = newOrdering(sortExprs, Nil) + val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil + val ordering = newOrdering(boundSortExprs, Nil) RangeBoundOrdering(ordering, current, bound) - case RowFrame => RowBoundOrdering(offset) + + case (RangeFrame, _) => + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") } } @@ -157,13 +160,13 @@ case class WindowExec( * [[WindowExpression]]s and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type FrameKey = (String, FrameType, Expression, Expression) type ExpressionBuffer = mutable.Buffer[Expression] val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] // Add a function and its function to the map for a given frame. def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val key = (tpe, fr.frameType, fr.lower, fr.upper) val (es, fns) = framedFunctions.getOrElseUpdate( key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) es += e @@ -203,7 +206,7 @@ case class WindowExec( // Create the factory val factory = key match { // Offset Frame - case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + case ("OFFSET", _, IntegerLiteral(offset), _) => target: InternalRow => new OffsetWindowFunctionFrame( target, @@ -215,38 +218,38 @@ case class WindowExec( newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) + // Entire Partition Frame. 
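+        // Both the lower and the upper bound are unbounded, so the aggregate is evaluated once over the whole partition and every row receives the same result.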
+ case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => + target: InternalRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + // Growing Frame. - case ("AGGREGATE", frameType, None, Some(high)) => + case ("AGGREGATE", frameType, UnboundedPreceding, upper) => target: InternalRow => { new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, high)) + createBoundOrdering(frameType, upper)) } // Shrinking Frame. - case ("AGGREGATE", frameType, Some(low), None) => + case ("AGGREGATE", frameType, lower, UnboundedFollowing) => target: InternalRow => { new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low)) + createBoundOrdering(frameType, lower)) } // Moving Frame. - case ("AGGREGATE", frameType, Some(low), Some(high)) => + case ("AGGREGATE", frameType, lower, upper) => target: InternalRow => { new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low), - createBoundOrdering(frameType, high)) - } - - // Entire Partition Frame. - case ("AGGREGATE", frameType, None, None) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) + createBoundOrdering(frameType, lower), + createBoundOrdering(frameType, upper)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index f653890f6c7ba..f8b404de77a4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.Column +import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ /** @@ -123,7 +123,24 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rowsBetween. def rowsBetween(start: Long, end: Long): WindowSpec = { - between(RowFrame, start, end) + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary start is not a valid integer: $x") + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary end is not a valid integer: $x") + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(RowFrame, boundaryStart, boundaryEnd)) } /** @@ -174,28 +191,22 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rangeBetween. 
def rangeBetween(start: Long, end: Long): WindowSpec = { - between(RangeFrame, start, end) - } - - private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => CurrentRow case Long.MinValue => UnboundedPreceding - case x if x < 0 => ValuePreceding(-start.toInt) - case x if x > 0 => ValueFollowing(start.toInt) + case x => Literal(x) } val boundaryEnd = end match { case 0 => CurrentRow case Long.MaxValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-end.toInt) - case x if x > 0 => ValueFollowing(end.toInt) + case x => Literal(x) } new WindowSpec( partitionSpec, orderSpec, - SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index 2b5b692d29ef4..f1461032065ad 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -23,3 +23,7 @@ SELECT float(1), double(1), decimal(1); SELECT date("2014-04-04"), timestamp(date("2014-04-04")); -- error handling: only one argument SELECT string(1, 2); + +-- SPARK-21555: RuntimeReplaceable used in group by +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st); +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value"); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql index c90a9c7f85587..85481cbbf9377 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql @@ -15,3 +15,6 @@ SELECT * FROM testData AS t(col1); -- Check alias duplication SELECT a AS col1, b AS col2 FROM testData AS t(c, d); + +-- Subquery aliases in FROM clause +SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2); diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c800fc3d49891..342e5719e9a60 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -1,24 +1,44 @@ -- Test data. 
CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES -(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) -AS testData(val, cate); +(null, 1L, "a"), (1, 1L, "a"), (1, 2L, "a"), (2, 2147483650L, "a"), (1, null, "b"), (2, 3L, "b"), +(3, 2147483650L, "b"), (null, null, null), (3, 1L, null) +AS testData(val, val_long, cate); -- RowsBetween SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData ORDER BY cate, val; SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; -- RangeBetween SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData ORDER BY cate, val; SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; -- RangeBetween with reverse OrderBy SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +-- Invalid window frame +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_date +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val; + + -- Window functions SELECT val, cate, max(val) OVER w AS max, diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 732b11050f461..e035505f15d28 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -122,3 +122,19 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException Function string accepts only one argument; line 1 pos 7 + + +-- !query 13 +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") +-- !query 14 schema +struct +-- !query 14 output +gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out 
b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out index 7abbcd834a523..4459f3186c77b 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 8 -- !query 0 @@ -61,3 +61,11 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException cannot resolve '`a`' given input columns: [t.c, t.d]; line 1 pos 7 + + +-- !query 7 +SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) +-- !query 7 schema +struct +-- !query 7 output +1 1 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index aa5856138ed81..97511068b323c 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,11 +1,12 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 19 -- !query 0 CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES -(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) -AS testData(val, cate) +(null, 1L, "a"), (1, 1L, "a"), (1, 2L, "a"), (2, 2147483650L, "a"), (1, null, "b"), (2, 3L, "b"), +(3, 2147483650L, "b"), (null, null, null), (3, 1L, null) +AS testData(val, val_long, cate) -- !query 0 schema struct<> -- !query 0 output @@ -47,11 +48,21 @@ NULL a 1 -- !query 3 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType does not match the expected data type 'IntegerType'.; line 1 pos 41 + + +-- !query 4 SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData ORDER BY cate, val --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output NULL NULL 0 3 NULL 1 NULL a 0 @@ -63,12 +74,12 @@ NULL a 0 3 b 2 --- !query 4 +-- !query 5 SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output NULL NULL NULL 3 NULL 3 NULL a NULL @@ -80,12 +91,29 @@ NULL a NULL 3 b 3 --- !query 5 +-- !query 6 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL +1 NULL 1 +1 a 4 +1 a 4 +2 a 2147483652 +2147483650 a 2147483650 +NULL b NULL +3 b 2147483653 +2147483650 b 2147483650 + + +-- !query 7 SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val --- !query 5 schema +-- !query 7 schema struct --- !query 5 output +-- !query 7 output NULL NULL NULL 3 NULL 3 NULL a NULL @@ -97,7 +125,73 @@ NULL a NULL 3 b 5 --- !query 6 +-- !query 8 +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 8 schema +struct<> +-- !query 
8 output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING' due to data type mismatch: Window frame upper bound '1' does not followes the lower bound 'unboundedfollowing$()'.; line 1 pos 33 + + +-- !query 9 +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame cannot be used in an unordered window specification.; line 1 pos 33 + + +-- !query 10 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY testdata.`val` ASC NULLS FIRST, testdata.`cate` ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame with value boundaries cannot be used in a window specification with multiple order by expressions: val#x ASC NULLS FIRST,cate#x ASC NULLS FIRST; line 1 pos 33 + + +-- !query 11 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_date +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_date() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'DateType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 + + +-- !query 12 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 'RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING' due to data type mismatch: The lower bound of a window frame must be less than or equal to the upper bound; line 1 pos 33 + + +-- !query 13 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +Frame bound value must be a literal.(line 2, pos 30) + +== SQL == +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +------------------------------^^^ + + +-- !query 14 SELECT val, cate, max(val) OVER w AS max, min(val) OVER w AS min, @@ -124,9 +218,9 @@ approx_count_distinct(val) OVER w AS approx_count_distinct FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val --- !query 6 schema +-- !query 14 schema struct --- !query 6 output +-- !query 14 output NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 @@ -138,11 +232,11 @@ NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 
0.25 0. 3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 --- !query 7 +-- !query 15 SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val --- !query 7 schema +-- !query 15 schema struct --- !query 7 output +-- !query 15 output NULL NULL NULL 3 NULL NULL NULL a NULL @@ -154,20 +248,20 @@ NULL a NULL 3 b NULL --- !query 8 +-- !query 16 SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val --- !query 8 schema +-- !query 16 schema struct<> --- !query 8 output +-- !query 16 output org.apache.spark.sql.AnalysisException Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table; --- !query 9 +-- !query 17 SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val --- !query 9 schema +-- !query 17 schema struct --- !query 9 output +-- !query 17 output NULL NULL 13 1.8571428571428572 3 NULL 13 1.8571428571428572 NULL a 13 1.8571428571428572 @@ -179,7 +273,7 @@ NULL a 13 1.8571428571428572 3 b 13 1.8571428571428572 --- !query 10 +-- !query 18 SELECT val, cate, first_value(false) OVER w AS first_value, first_value(true, true) OVER w AS first_value_ignore_null, @@ -190,9 +284,9 @@ last_value(false, false) OVER w AS last_value_contain_null FROM testData WINDOW w AS () ORDER BY cate, val --- !query 10 schema +-- !query 18 schema struct --- !query 10 output +-- !query 18 output NULL NULL false true false false true false 3 NULL false true false false true false NULL a false true false false true false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 204858fa29787..9806e57f08744 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -151,6 +151,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(2.0d), Row(2.0d))) } + test("row between should accept integer values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept integer/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), 
Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + } + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.createOrReplaceTempView("window_table") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 73098cdb92471..40235e32d35da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1304,6 +1304,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(rlike3.count() == 0) } } + + test("SPARK-21538: Attribute resolution inconsistency in Dataset API") { + val df = spark.range(3).withColumnRenamed("id", "x") + val expected = Row(0) :: Row(1) :: Row (2) :: Nil + checkAnswer(df.sort("id"), expected) + checkAnswer(df.sort(col("id")), expected) + checkAnswer(df.sort($"id"), expected) + checkAnswer(df.sort('id), expected) + checkAnswer(df.orderBy("id"), expected) + checkAnswer(df.orderBy(col("id")), expected) + checkAnswer(df.orderBy($"id"), expected) + checkAnswer(df.orderBy('id), expected) + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 50d8e3024598d..d194f58cd1cdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -127,9 +127,10 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES ) val groupKey = InternalRow(UTF8String.fromString("cats")) + val row = map.getAggregationBuffer(groupKey) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) - assert(map.getAggregationBuffer(groupKey) != null) + assert(row != null) val iter = map.iterator() assert(iter.next()) iter.getKey.getString(0) should be ("cats") @@ -138,7 +139,7 @@ class UnsafeFixedWidthAggregationMapSuite // Modifications to rows retrieved from the map should update the values in the map iter.getValue.setInt(0, 42) - map.getAggregationBuffer(groupKey).getInt(0) should be (42) + row.getInt(0) should be (42) map.free() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 55b465578a42d..4893b52f240ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{BinaryType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -857,6 +857,449 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "nanData-floating_point.json") } + test("array type conversion") { + val json = + s""" + |{ + | "schema" 
: { + | "fields" : [ { + | "name" : "a_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "b_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "c_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | }, { + | "name" : "d_arr", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : true, + | "type" : { + | "name" : "list" + | }, + | "children" : [ { + | "name" : "element", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "OFFSET", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "a_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 4, 5 ] + | } ] + | }, { + | "name" : "b_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 0, 1, 0 ], + | "OFFSET" : [ 0, 2, 2, 2, 2 ], + | "children" : [ { + | "name" : "element", + | "count" : 2, + | "VALIDITY" : [ 1, 1 ], + | "DATA" : [ 1, 2 ] + | } ] + | }, { + | "name" : "c_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 4, 4, 5 ], + | "children" : [ { + | "name" : "element", + | "count" : 5, + | "VALIDITY" : [ 1, 1, 1, 0, 1 ], + | "DATA" : [ 1, 2, 3, 0, 5 ] + | } ] + | }, { + | "name" : "d_arr", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 
1, 1 ], + | "OFFSET" : [ 0, 1, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "OFFSET" : [ 0, 2, 3, 3, 4 ], + | "children" : [ { + | "name" : "element", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ 1, 2, 3, 5 ] + | } ] + | } ] + | } ] + | } ] + |} + """.stripMargin + + val aArr = Seq(Seq(1, 2), Seq(3, 4), Seq(), Seq(5)) + val bArr = Seq(Some(Seq(1, 2)), None, Some(Seq()), None) + val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) + val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) + + val df = aArr.zip(bArr).zip(cArr).zip(dArr).map { + case (((a, b), c), d) => (a, b, c, d) + }.toDF("a_arr", "b_arr", "c_arr", "d_arr") + + collectAndValidate(df, json, "arrayData.json") + } + + test("struct type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "a_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "b_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : false, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "c_struct", + | "nullable" : false, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | }, { + | "name" : "d_struct", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "nested", + | "nullable" : true, + | "type" : { + | "name" : "struct" + | }, + | "children" : [ { + | "name" : "i", + | "nullable" : true, + | "type" : { + | "name" : "int", + | "bitWidth" : 32, + | "isSigned" : true + | }, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 3, + | "columns" : [ { + | "name" : "a_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, 
{ + | "name" : "b_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "c_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 1, 1 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "DATA" : [ 1, 2, 3 ] + | } ] + | }, { + | "name" : "d_struct", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 1 ], + | "children" : [ { + | "name" : "nested", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 0 ], + | "children" : [ { + | "name" : "i", + | "count" : 3, + | "VALIDITY" : [ 1, 0, 0 ], + | "DATA" : [ 1, 2, 0 ] + | } ] + | } ] + | } ] + | } ] + |} + """.stripMargin + + val aStruct = Seq(Row(1), Row(2), Row(3)) + val bStruct = Seq(Row(1), null, Row(3)) + val cStruct = Seq(Row(1), Row(null), Row(3)) + val dStruct = Seq(Row(Row(1)), null, Row(null)) + val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { + case (((a, b), c), d) => Row(a, b, c, d) + } + + val rdd = sparkContext.parallelize(data) + val schema = new StructType() + .add("a_struct", new StructType().add("i", IntegerType, nullable = false), nullable = false) + .add("b_struct", new StructType().add("i", IntegerType, nullable = false), nullable = true) + .add("c_struct", new StructType().add("i", IntegerType, nullable = true), nullable = false) + .add("d_struct", new StructType().add("nested", new StructType().add("i", IntegerType))) + val df = spark.createDataFrame(rdd, schema) + + collectAndValidate(df, json, "structData.json") + } + test("partitioned DataFrame") { val json1 = s""" @@ -1015,6 +1458,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch) val df = spark.sparkContext.parallelize(1 to totalRecords, 2).toDF("i") val arrowPayloads = df.toArrowPayload.collect() + assert(arrowPayloads.length >= 4) val allocator = new RootAllocator(Long.MaxValue) val arrowRecordBatches = arrowPayloads.map(_.loadBatch(allocator)) var recordCount = 0 @@ -1039,7 +1483,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } runUnsupported { decimalData.toArrowPayload.collect() } - runUnsupported { arrayData.toDF().toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala new file mode 100644 index 0000000000000..e9a629315f5f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.arrow + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.vectorized.ArrowColumnVector +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ArrowWriterSuite extends SparkFunSuite { + + test("simple") { + def check(dt: DataType, data: Seq[Any]): Unit = { + val schema = new StructType().add("value", dt, nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + data.zipWithIndex.foreach { + case (null, rowId) => assert(reader.isNullAt(rowId)) + case (datum, rowId) => + val value = dt match { + case BooleanType => reader.getBoolean(rowId) + case ByteType => reader.getByte(rowId) + case ShortType => reader.getShort(rowId) + case IntegerType => reader.getInt(rowId) + case LongType => reader.getLong(rowId) + case FloatType => reader.getFloat(rowId) + case DoubleType => reader.getDouble(rowId) + case StringType => reader.getUTF8String(rowId) + case BinaryType => reader.getBinary(rowId) + } + assert(value === datum) + } + + writer.root.close() + } + check(BooleanType, Seq(true, null, false)) + check(ByteType, Seq(1.toByte, 2.toByte, null, 4.toByte)) + check(ShortType, Seq(1.toShort, 2.toShort, null, 4.toShort)) + check(IntegerType, Seq(1, 2, null, 4)) + check(LongType, Seq(1L, 2L, null, 4L)) + check(FloatType, Seq(1.0f, 2.0f, null, 4.0f)) + check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) + check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) + check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + } + + test("get multiple") { + def check(dt: DataType, data: Seq[Any]): Unit = { + val schema = new StructType().add("value", dt, nullable = false) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + data.foreach { datum => + writer.write(InternalRow(datum)) + } + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + val values = dt match { + case BooleanType => reader.getBooleans(0, data.size) + case ByteType => reader.getBytes(0, data.size) + case ShortType => reader.getShorts(0, data.size) + case IntegerType => reader.getInts(0, data.size) + case LongType => reader.getLongs(0, data.size) + case FloatType => reader.getFloats(0, data.size) + case DoubleType => reader.getDoubles(0, data.size) + } + assert(values === data) + + writer.root.close() + } + check(BooleanType, Seq(true, false)) + check(ByteType, (0 until 10).map(_.toByte)) + check(ShortType, (0 until 10).map(_.toShort)) + check(IntegerType, (0 until 10)) + check(LongType, (0 until 10).map(_.toLong)) + check(FloatType, (0 until 10).map(_.toFloat)) + check(DoubleType, (0 until 10).map(_.toDouble)) + } + + test("array") { + val schema = new StructType() + .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) + writer.write(InternalRow(ArrayData.toArrayData(Array(4, 5)))) + writer.write(InternalRow(null)) + 
writer.write(InternalRow(ArrayData.toArrayData(Array.empty[Int]))) + writer.write(InternalRow(ArrayData.toArrayData(Array(6, null, 8)))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 3) + assert(array0.getInt(0) === 1) + assert(array0.getInt(1) === 2) + assert(array0.getInt(2) === 3) + + val array1 = reader.getArray(1) + assert(array1.numElements() === 2) + assert(array1.getInt(0) === 4) + assert(array1.getInt(1) === 5) + + assert(reader.isNullAt(2)) + + val array3 = reader.getArray(3) + assert(array3.numElements() === 0) + + val array4 = reader.getArray(4) + assert(array4.numElements() === 3) + assert(array4.getInt(0) === 6) + assert(array4.isNullAt(1)) + assert(array4.getInt(2) === 8) + + writer.root.close() + } + + test("nested array") { + val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(ArrayData.toArrayData(Array( + ArrayData.toArrayData(Array(1, 2, 3)), + ArrayData.toArrayData(Array(4, 5)), + null, + ArrayData.toArrayData(Array.empty[Int]), + ArrayData.toArrayData(Array(6, null, 8)))))) + writer.write(InternalRow(null)) + writer.write(InternalRow(ArrayData.toArrayData(Array.empty))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val array0 = reader.getArray(0) + assert(array0.numElements() === 5) + + val array00 = array0.getArray(0) + assert(array00.numElements() === 3) + assert(array00.getInt(0) === 1) + assert(array00.getInt(1) === 2) + assert(array00.getInt(2) === 3) + + val array01 = array0.getArray(1) + assert(array01.numElements() === 2) + assert(array01.getInt(0) === 4) + assert(array01.getInt(1) === 5) + + assert(array0.isNullAt(2)) + + val array03 = array0.getArray(3) + assert(array03.numElements() === 0) + + val array04 = array0.getArray(4) + assert(array04.numElements() === 3) + assert(array04.getInt(0) === 6) + assert(array04.isNullAt(1)) + assert(array04.getInt(2) === 8) + + assert(reader.isNullAt(1)) + + val array2 = reader.getArray(2) + assert(array2.numElements() === 0) + + writer.root.close() + } + + test("struct") { + val schema = new StructType() + .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) + writer.write(InternalRow(InternalRow(null, null))) + writer.write(InternalRow(null)) + writer.write(InternalRow(InternalRow(4, null))) + writer.write(InternalRow(InternalRow(null, UTF8String.fromString("str5")))) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct0 = reader.getStruct(0, 2) + assert(struct0.getInt(0) === 1) + assert(struct0.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct1 = reader.getStruct(1, 2) + assert(struct1.isNullAt(0)) + assert(struct1.isNullAt(1)) + + assert(reader.isNullAt(2)) + + val struct3 = reader.getStruct(3, 2) + assert(struct3.getInt(0) === 4) + assert(struct3.isNullAt(1)) + + val struct4 = reader.getStruct(4, 2) + assert(struct4.isNullAt(0)) + assert(struct4.getUTF8String(1) === UTF8String.fromString("str5")) + + writer.root.close() + } + + test("nested struct") { + val schema = new StructType().add("struct", + new StructType().add("nested", new StructType().add("i", 
IntegerType).add("str", StringType))) + val writer = ArrowWriter.create(schema) + assert(writer.schema === schema) + + writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) + writer.write(InternalRow(InternalRow(InternalRow(null, null)))) + writer.write(InternalRow(InternalRow(null))) + writer.write(InternalRow(null)) + writer.finish() + + val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0)) + + val struct00 = reader.getStruct(0, 1).getStruct(0, 2) + assert(struct00.getInt(0) === 1) + assert(struct00.getUTF8String(1) === UTF8String.fromString("str1")) + + val struct10 = reader.getStruct(1, 1).getStruct(0, 2) + assert(struct10.isNullAt(0)) + assert(struct10.isNullAt(1)) + + val struct2 = reader.getStruct(2, 1) + assert(struct2.isNullAt(0)) + + assert(reader.isNullAt(3)) + + writer.root.close() + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 17589cf44b998..f517bffccdf31 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -103,7 +103,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } val content = -
<h4>SQL Statistics</h4> ++ + <h4>SQL Statistics ({numStatement})</h4> ++
    {table.getOrElse("No statistics have been generated yet.")} @@ -164,7 +164,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } val content = -
<h4>Session Statistics</h4> ++ + <h4>Session Statistics ({numBatches})</h4> ++
      {table.getOrElse("No statistics have been generated yet.")} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index 149ce1e195111..90f90599d5bf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -98,27 +98,27 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), - s"(PARTITION BY `a` $frame)" + s"(PARTITION BY `a` ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, Nil, frame), - s"(PARTITION BY `a`, `b` $frame)" + s"(PARTITION BY `a`, `b` ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), - s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" + s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST ${frame.sql})" ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 905b1c52afa69..b8a5a96faf15c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -164,6 +164,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( /** Clear the old time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) batchTimeToSelectedFiles.synchronized { val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) batchTimeToSelectedFiles --= oldFiles.keys